diff --git a/src/custom_timm/__pycache__/__init__.cpython-312.pyc b/src/custom_timm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa98de193cb93fccfacbd1d2d6882ac643442335 Binary files /dev/null and b/src/custom_timm/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/custom_timm/__pycache__/version.cpython-312.pyc b/src/custom_timm/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfffe5047f781d3fafdbb937390929424f56c9ea Binary files /dev/null and b/src/custom_timm/__pycache__/version.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/__init__.cpython-312.pyc b/src/custom_timm/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa1c1e61328c0a137846a2f20438461568cf64a Binary files /dev/null and b/src/custom_timm/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc b/src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a6f5b1c13d3fb0c49f358ed294c6e656ac81c4 Binary files /dev/null and b/src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/config.cpython-312.pyc b/src/custom_timm/data/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b6ee41069013c0135b212f179f339561b14cce3 Binary files /dev/null and b/src/custom_timm/data/__pycache__/config.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/constants.cpython-312.pyc b/src/custom_timm/data/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa2ba6b69152a76d81e632c7ef4ec60a1222cc43 Binary files /dev/null and b/src/custom_timm/data/__pycache__/constants.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/dataset.cpython-312.pyc b/src/custom_timm/data/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccf291466ed2c3441ca807b66779ce6485b3a2e4 Binary files /dev/null and b/src/custom_timm/data/__pycache__/dataset.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc b/src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ded6dc1c8c5efa5b8960950bac3881a1213f0de Binary files /dev/null and b/src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc b/src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f1589724b789d596b0309c7b61a3524010e2ad5 Binary files /dev/null and b/src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/loader.cpython-312.pyc b/src/custom_timm/data/__pycache__/loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ad4da038894fdb4c2012b23754e1e42085874d4 Binary files /dev/null and b/src/custom_timm/data/__pycache__/loader.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/mixup.cpython-312.pyc b/src/custom_timm/data/__pycache__/mixup.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292559403c2838611967d989ec483d0d41d60dde Binary files /dev/null and b/src/custom_timm/data/__pycache__/mixup.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc b/src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b171a9cb37676cc402c4a402f75657841bbe72d Binary files /dev/null and b/src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc b/src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb9275386212b73f060def8f89d5955e5f7441e Binary files /dev/null and b/src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/transforms.cpython-312.pyc b/src/custom_timm/data/__pycache__/transforms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7ef4587f61d3259244f8ac14de9b8104001bfa9 Binary files /dev/null and b/src/custom_timm/data/__pycache__/transforms.cpython-312.pyc differ diff --git a/src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc b/src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bca36da8902ed4accf33e7f22eb62838e986df7e Binary files /dev/null and b/src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc differ diff --git a/src/custom_timm/data/parsers/__init__.py b/src/custom_timm/data/parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e820d5e027ba82c937829ad50b2b2c9a97d2f28 --- /dev/null +++ b/src/custom_timm/data/parsers/__init__.py @@ -0,0 +1,2 @@ +from .parser_factory import create_parser +from .img_extensions import * diff --git a/src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc b/src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d660480360934ef44d06c987475fabfbbdfd276 Binary files /dev/null and b/src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc differ diff --git a/src/custom_timm/data/parsers/class_map.py b/src/custom_timm/data/parsers/class_map.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf3f57e014566e165374acae8dec031c02048f8 --- /dev/null +++ b/src/custom_timm/data/parsers/class_map.py @@ -0,0 +1,22 @@ +import os +import pickle + +def load_class_map(map_or_filename, root=''): + if isinstance(map_or_filename, dict): + assert dict, 'class_map dict must be non-empty' + return map_or_filename + class_map_path = map_or_filename + if not os.path.exists(class_map_path): + class_map_path = os.path.join(root, class_map_path) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename + class_map_ext = os.path.splitext(map_or_filename)[-1].lower() + if class_map_ext == '.txt': + with open(class_map_path) as f: + class_to_idx = {v.strip(): k for k, v in enumerate(f)} + elif class_map_ext == '.pkl': + with open(class_map_path,'rb') as f: + class_to_idx = pickle.load(f) + else: + assert False, f'Unsupported class map file extension ({class_map_ext}).' + return class_to_idx + diff --git a/src/custom_timm/data/parsers/img_extensions.py b/src/custom_timm/data/parsers/img_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..45c85aabd00ca5ebf7bd6fa85c674570fe60f9c8 --- /dev/null +++ b/src/custom_timm/data/parsers/img_extensions.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] + + +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use +_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync + + +def _set_extensions(extensions): + global IMG_EXTENSIONS + global _IMG_EXTENSIONS_SET + dedupe = set() # NOTE de-duping tuple while keeping original order + IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) + _IMG_EXTENSIONS_SET = set(extensions) + + +def _valid_extension(x: str): + return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') + + +def is_img_extension(ext): + return ext in _IMG_EXTENSIONS_SET + + +def get_img_extensions(as_set=False): + return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) + + +def set_img_extensions(extensions): + assert len(extensions) + for x in extensions: + assert _valid_extension(x) + _set_extensions(extensions) + + +def add_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + for x in ext: + assert _valid_extension(x) + extensions = IMG_EXTENSIONS + tuple(ext) + _set_extensions(extensions) + + +def del_img_extensions(ext): + if not isinstance(ext, (list, tuple, set)): + ext = (ext,) + extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) + _set_extensions(extensions) diff --git a/src/custom_timm/data/parsers/parser.py b/src/custom_timm/data/parsers/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..76ab6d18283644702424d0ff2af5832d6d6dd3b7 --- /dev/null +++ b/src/custom_timm/data/parsers/parser.py @@ -0,0 +1,17 @@ +from abc import abstractmethod + + +class Parser: + def __init__(self): + pass + + @abstractmethod + def _filename(self, index, basename=False, absolute=False): + pass + + def filename(self, index, basename=False, absolute=False): + return self._filename(index, basename=basename, absolute=absolute) + + def filenames(self, basename=False, absolute=False): + return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] + diff --git a/src/custom_timm/data/parsers/parser_factory.py b/src/custom_timm/data/parsers/parser_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0665c02a8b4db12b8ac6b7095999751c5b26f384 --- /dev/null +++ b/src/custom_timm/data/parsers/parser_factory.py @@ -0,0 +1,28 @@ +import os + +from .parser_image_folder import ParserImageFolder +from .parser_image_in_tar import ParserImageInTar + + +def create_parser(name, root, split='train', **kwargs): + name = name.lower() + name = name.split('/', 2) + prefix = '' + if len(name) > 1: + prefix = name[0] + name = name[-1] + + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to + # explicitly select other options shortly + if prefix == 'tfds': + from .parser_tfds import ParserTfds # defer tensorflow import + parser = ParserTfds(root, name, split=split, **kwargs) + else: + assert os.path.exists(root) + # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder + # FIXME support split here, in parser? + if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': + parser = ParserImageInTar(root, **kwargs) + else: + parser = ParserImageFolder(root, **kwargs) + return parser diff --git a/src/custom_timm/data/parsers/parser_image_folder.py b/src/custom_timm/data/parsers/parser_image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..d82b024377e99a26fb87c92256a076505d894666 --- /dev/null +++ b/src/custom_timm/data/parsers/parser_image_folder.py @@ -0,0 +1,90 @@ +""" A dataset parser that reads images from folders + +Folders are scannerd recursively to find image files. Labels are based +on the folder hierarchy, just leaf folders by default. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +from typing import Dict, List, Optional, Set, Tuple, Union + +from custom_timm.utils.misc import natural_key + +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + + +def find_images_and_targets( + folder: str, + types: Optional[Union[List, Tuple, Set]] = None, + class_to_idx: Optional[Dict] = None, + leaf_name_only: bool = True, + sort: bool = True +): + """ Walk folder recursively to discover images and map them to classes by folder names. + + Args: + folder: root of folder to recrusively search + types: types (file extensions) to search for in path + class_to_idx: specify mapping for class (folder name) to class index if set + leaf_name_only: use only leaf-name of folder walk for class names + sort: re-sort found images by name (for consistent ordering) + + Returns: + A list of image and target tuples, class_to_idx mapping + """ + types = get_img_extensions(as_set=True) if not types else set(types) + labels = [] + filenames = [] + for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): + rel_path = os.path.relpath(root, folder) if (root != folder) else '' + label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') + for f in files: + base, ext = os.path.splitext(f) + if ext.lower() in types: + filenames.append(os.path.join(root, f)) + labels.append(label) + if class_to_idx is None: + # building class index + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] + if sort: + images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) + return images_and_targets, class_to_idx + + +class ParserImageFolder(Parser): + + def __init__( + self, + root, + class_map=''): + super().__init__() + + self.root = root + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) + if len(self.samples) == 0: + raise RuntimeError( + f'Found 0 images in subfolders of {root}. ' + f'Supported image extensions are {", ".join(get_img_extensions())}') + + def __getitem__(self, index): + path, target = self.samples[index] + return open(path, 'rb'), target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0] + if basename: + filename = os.path.basename(filename) + elif not absolute: + filename = os.path.relpath(filename, self.root) + return filename diff --git a/src/custom_timm/data/parsers/parser_image_in_tar.py b/src/custom_timm/data/parsers/parser_image_in_tar.py new file mode 100644 index 0000000000000000000000000000000000000000..7d3c1765b5bd3809f93a5c1707b472f7f54e5eb7 --- /dev/null +++ b/src/custom_timm/data/parsers/parser_image_in_tar.py @@ -0,0 +1,229 @@ +""" A dataset parser that reads tarfile based datasets + +This parser can read and extract image samples from: +* a single tar of image files +* a folder of multiple tarfiles containing imagefiles +* a tar of tars containing image files + +Labels are based on the combined folder and/or tar name structure. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +import pickle +import tarfile +from glob import glob +from typing import List, Tuple, Dict, Set, Optional, Union + +import numpy as np + +from custom_timm.utils.misc import natural_key + +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + +_logger = logging.getLogger(__name__) +CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' + + +class TarState: + + def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None): + self.tf: tarfile.TarFile = tf + self.ti: tarfile.TarInfo = ti + self.children: Dict[str, TarState] = {} # child states (tars within tars) + + def reset(self): + self.tf = None + + +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]): + sample_count = 0 + for i, ti in enumerate(tf): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + name, ext = os.path.splitext(basename) + ext = ext.lower() + if ext == '.tar': + with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf: + child_info = dict( + name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[]) + sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions) + _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.') + parent_info['children'].append(child_info) + elif ext in extensions: + parent_info['samples'].append(ti) + sample_count += 1 + return sample_count + + +def extract_tarinfos( + root, + class_name_to_idx: Optional[Dict] = None, + cache_tarinfo: Optional[bool] = None, + extensions: Optional[Union[List, Tuple, Set]] = None, + sort: bool = True +): + extensions = get_img_extensions(as_set=True) if not extensions else set(extensions) + root_is_tar = False + if os.path.isfile(root): + assert os.path.splitext(root)[-1].lower() == '.tar' + tar_filenames = [root] + root, root_name = os.path.split(root) + root_name = os.path.splitext(root_name)[0] + root_is_tar = True + else: + root_name = root.strip(os.path.sep).split(os.path.sep)[-1] + tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True) + num_tars = len(tar_filenames) + tar_bytes = sum([os.path.getsize(f) for f in tar_filenames]) + assert num_tars, f'No .tar files found at specified path ({root}).' + + _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...') + info = dict(tartrees=[]) + cache_path = '' + if cache_tarinfo is None: + cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB + if cache_tarinfo: + cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX + cache_path = os.path.join(root, cache_filename) + if os.path.exists(cache_path): + _logger.info(f'Reading tar info from cache file {cache_path}.') + with open(cache_path, 'rb') as pf: + info = pickle.load(pf) + assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles" + else: + for i, fn in enumerate(tar_filenames): + path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0] + with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode + parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[]) + num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions) + num_children = len(parent_info["children"]) + _logger.debug( + f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.') + info['tartrees'].append(parent_info) + if cache_path: + _logger.info(f'Writing tar info to cache file {cache_path}.') + with open(cache_path, 'wb') as pf: + pickle.dump(info, pf) + + samples = [] + labels = [] + build_class_map = False + if class_name_to_idx is None: + build_class_map = True + + # Flatten tartree info into lists of samples and targets w/ targets based on label id via + # class map arg or from unique paths. + # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children + # this covers my current use cases and keeps things a little easier to test for now. + tarfiles = [] + + def _label_from_paths(*path, leaf_only=True): + path = os.path.join(*path).strip(os.path.sep) + return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_') + + def _add_samples(info, fn): + added = 0 + for s in info['samples']: + label = _label_from_paths(info['path'], os.path.dirname(s.path)) + if not build_class_map and label not in class_name_to_idx: + continue + samples.append((s, fn, info['ti'])) + labels.append(label) + added += 1 + return added + + _logger.info(f'Collecting samples and building tar states.') + for parent_info in info['tartrees']: + # if tartree has children, we assume all samples are at the child level + tar_name = None if root_is_tar else parent_info['name'] + tar_state = TarState() + parent_added = 0 + for child_info in parent_info['children']: + child_added = _add_samples(child_info, fn=tar_name) + if child_added: + tar_state.children[child_info['name']] = TarState(ti=child_info['ti']) + parent_added += child_added + parent_added += _add_samples(parent_info, fn=tar_name) + if parent_added: + tarfiles.append((tar_name, tar_state)) + del info + + if build_class_map: + # build class index + sorted_labels = list(sorted(set(labels), key=natural_key)) + class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + + _logger.info(f'Mapping targets and sorting samples.') + samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx] + if sort: + samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path)) + samples, targets = zip(*samples_and_targets) + samples = np.array(samples) + targets = np.array(targets) + _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.') + return samples, targets, class_name_to_idx, tarfiles + + +class ParserImageInTar(Parser): + """ Multi-tarfile dataset parser where there is one .tar file per class + """ + + def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None): + super().__init__() + + class_name_to_idx = None + if class_map: + class_name_to_idx = load_class_map(class_map, root) + self.root = root + self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( + self.root, + class_name_to_idx=class_name_to_idx, + cache_tarinfo=cache_tarinfo + ) + self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} + if len(tarfiles) == 1 and tarfiles[0][0] is None: + self.root_is_tar = True + self.tar_state = tarfiles[0][1] + else: + self.root_is_tar = False + self.tar_state = dict(tarfiles) + self.cache_tarfiles = cache_tarfiles + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + sample = self.samples[index] + target = self.targets[index] + sample_ti, parent_fn, child_ti = sample + parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root + + tf = None + cache_state = None + if self.cache_tarfiles: + cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn] + tf = cache_state.tf + if tf is None: + tf = tarfile.open(parent_abs) + if self.cache_tarfiles: + cache_state.tf = tf + if child_ti is not None: + ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None + if ctf is None: + ctf = tarfile.open(fileobj=tf.extractfile(child_ti)) + if self.cache_tarfiles: + cache_state.children[child_ti.name].tf = ctf + tf = ctf + + return tf.extractfile(sample_ti), target + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/src/custom_timm/data/parsers/parser_image_tar.py b/src/custom_timm/data/parsers/parser_image_tar.py new file mode 100644 index 0000000000000000000000000000000000000000..c5520ee64c1d798a37d45b5361ab3b800f5adbe6 --- /dev/null +++ b/src/custom_timm/data/parsers/parser_image_tar.py @@ -0,0 +1,74 @@ +""" A dataset parser that reads single tarfile based datasets + +This parser can read datasets consisting if a single tarfile containing images. +I am planning to deprecated it in favour of ParerImageInTar. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import tarfile + +from custom_timm.utils.misc import natural_key + +from .class_map import load_class_map +from .img_extensions import get_img_extensions +from .parser import Parser + + +def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + extensions = get_img_extensions(as_set=True) + files = [] + labels = [] + for ti in tarfile.getmembers(): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + label = os.path.basename(dirname) + ext = os.path.splitext(basename)[1] + if ext.lower() in extensions: + files.append(ti) + labels.append(label) + if class_to_idx is None: + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] + if sort: + tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) + return tarinfo_and_targets, class_to_idx + + +class ParserImageTar(Parser): + """ Single tarfile dataset where classes are mapped to folders within tar + NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can + operate on folders of tars or tars in tars. + """ + def __init__(self, root, class_map=''): + super().__init__() + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + assert os.path.isfile(root) + self.root = root + + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later + self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) + self.imgs = self.samples + self.tarfile = None # lazy init in __getitem__ + + def __getitem__(self, index): + if self.tarfile is None: + self.tarfile = tarfile.open(self.root) + tarinfo, target = self.samples[index] + fileobj = self.tarfile.extractfile(tarinfo) + return fileobj, target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/src/custom_timm/data/parsers/parser_tfds.py b/src/custom_timm/data/parsers/parser_tfds.py new file mode 100644 index 0000000000000000000000000000000000000000..739f3813d0ad20bcb92676662dad62d53be1fe70 --- /dev/null +++ b/src/custom_timm/data/parsers/parser_tfds.py @@ -0,0 +1,301 @@ +""" Dataset parser interface that wraps TFDS datasets + +Wraps many (most?) TFDS image-classification datasets +from https://github.com/tensorflow/datasets +https://www.tensorflow.org/datasets/catalog/overview#image_classification + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch +import torch.distributed as dist +from PIL import Image + +try: + import tensorflow as tf + tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) + import tensorflow_datasets as tfds + try: + tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg + has_buggy_even_splits = False + except TypeError: + print("Warning: This version of tfds doesn't have the latest even_splits impl. " + "Please update or use tfds-nightly for better fine-grained split behaviour.") + has_buggy_even_splits = True + # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults) + # import resource + # low, high = resource.getrlimit(resource.RLIMIT_NOFILE) + # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high)) +except ImportError as e: + print(e) + print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") + exit(1) +from .parser import Parser + + +MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities +SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue +PREFETCH_SIZE = 2048 # examples to prefetch + + +def even_split_indices(split, n, num_examples): + partitions = [round(i * num_examples / n) for i in range(n + 1)] + return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)] + + +def get_class_labels(info): + if 'label' not in info.features: + return {} + class_label = info.features['label'] + class_to_idx = {n: class_label.str2int(n) for n in class_label.names} + return class_to_idx + + +class ParserTfds(Parser): + """ Wrap Tensorflow Datasets for use in PyTorch + + There several things to be aware of: + * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of + dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last + https://github.com/pytorch/pytorch/issues/33413 + * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch + from each worker could be a different size. For training this is worked around by option above, for + validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced + across replicas are of same size. This will slightly alter the results, distributed validation will not be + 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse + since there are up to N * J extra examples with IterableDatasets. + * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of + replicas and dataloader workers you can use. For really small datasets that only contain a few shards + you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the + benefit of distributed training or fast dataloading should be much less for small datasets. + * This wrapper is currently configured to return individual, decompressed image examples from the TFDS + dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible + to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream + components. + + """ + + def __init__( + self, + root, + name, + split='train', + is_training=False, + batch_size=None, + download=False, + repeats=0, + seed=42, + input_name='image', + input_image='RGB', + target_name='label', + target_image='', + prefetch_size=None, + shuffle_size=None, + max_threadpool_size=None + ): + """ Tensorflow-datasets Wrapper + + Args: + root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir) + name: tfds dataset name (eg `imagenet2012`) + split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) + is_training: training mode, shuffle enabled, dataset len rounded by batch_size + batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes + download: download and build TFDS dataset if set, otherwise must use tfds CLI + repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) + seed: common seed for shard shuffle across all distributed/worker instances + input_name: name of Feature to return as data (input) + input_image: image mode if input is an image (currently PIL mode string) + target_name: name of Feature to return as target (label) + target_image: image mode if target is an image (currently PIL mode string) + prefetch_size: override default tf.data prefetch buffer size + shuffle_size: override default tf.data shuffle buffer size + max_threadpool_size: override default threadpool size for tf.data + """ + super().__init__() + self.root = root + self.split = split + self.is_training = is_training + if self.is_training: + assert batch_size is not None, \ + "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" + self.batch_size = batch_size + self.repeats = repeats + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + + # performance settings + self.prefetch_size = prefetch_size or PREFETCH_SIZE + self.shuffle_size = shuffle_size or SHUFFLE_SIZE + self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE + + # TFDS builder and split information + self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature + self.input_image = input_image + self.target_name = target_name + self.target_image = target_image + self.builder = tfds.builder(name, data_dir=root) + # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag + if download: + self.builder.download_and_prepare() + self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} + self.split_info = self.builder.info.splits[split] + self.num_examples = self.split_info.num_examples + + # Distributed world state + self.dist_rank = 0 + self.dist_num_replicas = 1 + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() + + # Attributes that are updated in _lazy_init, including the tf.data pipeline itself + self.global_num_workers = 1 + self.worker_info = None + self.worker_seed = 0 # seed unique to each work instance + self.subsplit = None # set when data is distributed across workers using sub-splits + self.ds = None # initialized lazily on each dataloader worker process + + def _lazy_init(self): + """ Lazily initialize the dataset. + + This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that + will be using the dataset instance. The __init__ method is called on the main process, + this will be called in a dataloader worker process. + + NOTE: There will be problems if you try to re-use this dataset across different loader/worker + instances once it has been initialized. Do not call any dataset methods that can call _lazy_init + before it is passed to dataloader. + """ + worker_info = torch.utils.data.get_worker_info() + + # setup input context to split dataset across distributed processes + num_workers = 1 + global_worker_id = 0 + if worker_info is not None: + self.worker_info = worker_info + self.worker_seed = worker_info.seed + num_workers = worker_info.num_workers + self.global_num_workers = self.dist_num_replicas * num_workers + global_worker_id = self.dist_rank * num_workers + worker_info.id + + """ Data sharding + InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. + My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) + between the splits each iteration, but that understanding could be wrong. + + I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing + the data across workers. For training InputContext is used to assign shards to nodes unless num_shards + in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or + for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding. + """ + should_subsplit = self.global_num_workers > 1 and ( + self.split_info.num_shards < self.global_num_workers or not self.is_training) + if should_subsplit: + # split the dataset w/o using sharding for more even examples / worker, can result in less optimal + # read patterns for distributed training (overlap across shards) so better to use InputContext there + if has_buggy_even_splits: + # my even_split workaround doesn't work on subsplits, upgrade tfds! + if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples) + self.subsplit = subsplits[global_worker_id] + else: + subsplits = tfds.even_splits(self.split, self.global_num_workers) + self.subsplit = subsplits[global_worker_id] + + input_context = None + if self.global_num_workers > 1 and self.subsplit is None: + # set input context to divide shards among distributed replicas + input_context = tf.distribute.InputContext( + num_input_pipelines=self.global_num_workers, + input_pipeline_id=global_worker_id, + num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? + ) + read_config = tfds.ReadConfig( + shuffle_seed=self.common_seed, + shuffle_reshuffle_each_iteration=True, + input_context=input_context) + ds = self.builder.as_dataset( + split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) + # avoid overloading threading w/ combo of TF ds threads + PyTorch workers + options = tf.data.Options() + thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' + getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers) + getattr(options, thread_member).max_intra_op_parallelism = 1 + ds = ds.with_options(options) + if self.is_training or self.repeats > 1: + # to prevent excessive drop_last batch behaviour w/ IterableDatasets + # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading + ds = ds.repeat() # allow wrap around and break iteration manually + if self.is_training: + ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) + self.ds = tfds.as_numpy(ds) + + def __iter__(self): + if self.ds is None: + self._lazy_init() + + # Compute a rounded up sample count that is used to: + # 1. make batches even cross workers & replicas in distributed validation. + # This adds extra examples and will slightly alter validation results. + # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size + # batches are produced (underlying tfds iter wraps around) + target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers) + if self.is_training: + # round up to nearest batch_size per worker-replica + target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size + + # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) + example_count = 0 + for example in self.ds: + input_data = example[self.input_name] + if self.input_image: + input_data = Image.fromarray(input_data, mode=self.input_image) + target_data = example[self.target_name] + if self.target_image: + target_data = Image.fromarray(target_data, mode=self.target_image) + yield input_data, target_data + example_count += 1 + if self.is_training and example_count >= target_example_count: + # Need to break out of loop when repeat() is enabled for training w/ oversampling + # this results in extra examples per epoch but seems more desirable than dropping + # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) + break + + # Pad across distributed nodes (make counts equal by adding examples) + if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ + 0 < example_count < target_example_count: + # Validation batch padding only done for distributed training where results are reduced across nodes. + # For single process case, it won't matter if workers return different batch sizes. + # If using input_context or % based splits, sample count can vary significantly across workers and this + # approach should not be used (hence disabled if self.subsplit isn't set). + while example_count < target_example_count: + yield input_data, target_data # yield prev sample again + example_count += 1 + + def __len__(self): + # this is just an estimate and does not factor in extra examples added to pad batches based on + # complete worker & replica info (not available until init in dataloader). + return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas) + + def _filename(self, index, basename=False, absolute=False): + assert False, "Not supported" # no random access to examples + + def filenames(self, basename=False, absolute=False): + """ Return all filenames in dataset, overrides base""" + if self.ds is None: + self._lazy_init() + names = [] + for sample in self.ds: + if len(names) > self.num_examples: + break # safety for ds.repeat() case + if 'file_name' in sample: + name = sample['file_name'] + elif 'filename' in sample: + name = sample['filename'] + elif 'id' in sample: + name = sample['id'] + else: + assert False, "No supported name field present" + names.append(name) + return names diff --git a/src/custom_timm/models/gluon_resnet.py b/src/custom_timm/models/gluon_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f24eb3e682bc09df9434ba3bdf0248f303095f6f --- /dev/null +++ b/src/custom_timm/models/gluon_resnet.py @@ -0,0 +1,245 @@ +"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants +This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions +and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) +by Ross Wightman +""" + +from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SEModule +from .registry import register_model +from .resnet import ResNet, Bottleneck, BasicBlock + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'), + 'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'), + 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), + 'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), + 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), + 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), + 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), + 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), + 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), + 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), + 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), + 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), + 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth', + first_conv='conv1.0'), +} + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) + + +@register_model +def gluon_resnet18_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('gluon_resnet18_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet34_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet34_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet50_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('gluon_resnet101_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('gluon_resnet152_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet50_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet101_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet152_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnet101_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1s', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt50-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_64x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-64x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_senet154(pretrained=False, **kwargs): + """Constructs an SENet-154 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_senet154', pretrained, **model_args) diff --git a/src/custom_timm/models/gluon_xception.py b/src/custom_timm/models/gluon_xception.py new file mode 100644 index 0000000000000000000000000000000000000000..809251b28dbecf867169010ac962a5fb5ca09e8d --- /dev/null +++ b/src/custom_timm/models/gluon_xception.py @@ -0,0 +1,267 @@ +"""Pytorch impl of Gluon Xception +This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl. + +Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html) +Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier, get_padding +from .registry import register_model + +__all__ = ['Xception65'] + +default_cfgs = { + 'gluon_xception65': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth', + 'input_size': (3, 299, 299), + 'crop_pct': 0.903, + 'pool_size': (10, 10), + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + }, +} + +""" PADDING NOTES +The original PyTorch and Gluon impl of these models dutifully reproduced the +aligned padding added to Tensorflow models for Deeplab. This padding was compensating +for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to. +""" + + +class SeparableConv2d(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + padding = get_padding(kernel_size, stride, dilation) + self.conv_dw = nn.Conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=inplanes, bias=bias) + self.bn = norm_layer(num_features=inplanes) + # pointwise convolution + self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn(x) + x = self.conv_pw(x) + return x + + +class Block(nn.Module): + def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None): + super(Block, self).__init__() + if isinstance(planes, (list, tuple)): + assert len(planes) == 3 + else: + planes = (planes,) * 3 + outplanes = planes[-1] + + if outplanes != inplanes or stride != 1: + self.skip = nn.Sequential() + self.skip.add_module('conv1', nn.Conv2d( + inplanes, outplanes, 1, stride=stride, bias=False)), + self.skip.add_module('bn1', norm_layer(num_features=outplanes)) + else: + self.skip = None + + rep = OrderedDict() + for i in range(3): + rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) + rep['conv%d' % (i + 1)] = SeparableConv2d( + inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer) + rep['bn%d' % (i + 1)] = norm_layer(planes[i]) + inplanes = planes[i] + + if not start_with_relu: + del rep['act1'] + else: + rep['act1'] = nn.ReLU(inplace=False) + self.rep = nn.Sequential(rep) + + def forward(self, x): + skip = x + if self.skip is not None: + skip = self.skip(skip) + x = self.rep(x) + skip + return x + + +class Xception65(nn.Module): + """Modified Aligned Xception. + + NOTE: only the 65 layer version is included here, the 71 layer variant + was not correct and had no pretrained weights + """ + + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, + drop_rate=0., global_pool='avg'): + super(Xception65, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + if output_stride == 32: + entry_block3_stride = 2 + exit_block20_stride = 2 + middle_dilation = 1 + exit_dilation = (1, 1) + elif output_stride == 16: + entry_block3_stride = 2 + exit_block20_stride = 1 + middle_dilation = 1 + exit_dilation = (1, 2) + elif output_stride == 8: + entry_block3_stride = 1 + exit_block20_stride = 1 + middle_dilation = 2 + exit_dilation = (2, 4) + else: + raise NotImplementedError + + # Entry flow + self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(num_features=32) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = norm_layer(num_features=64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer) + self.block1_act = nn.ReLU(inplace=True) + self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer) + self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer) + + # Middle flow + self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( + 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)])) + + # Exit flow + self.block20 = Block( + 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer) + self.block20_act = nn.ReLU(inplace=True) + + self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn3 = norm_layer(num_features=1536) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn4 = norm_layer(num_features=1536) + self.act4 = nn.ReLU(inplace=True) + + self.num_features = 2048 + self.conv5 = SeparableConv2d( + 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn5 = norm_layer(num_features=self.num_features) + self.act5 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block1_act'), + dict(num_chs=256, reduction=8, module='block3.rep.act1'), + dict(num_chs=728, reduction=16, module='block20.rep.act1'), + dict(num_chs=2048, reduction=32, module='act5'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv[12]|bn[12]', + blocks=[ + (r'^mid\.block(\d+)', None), + (r'^block(\d+)', None), + (r'^conv[345]|bn[345]', (99,)), + ], + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + # Entry flow + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block1_act(x) + # c1 = x + x = self.block2(x) + # c2 = x + x = self.block3(x) + + # Middle flow + x = self.mid(x) + # c3 = x + + # Exit flow + x = self.block20(x) + x = self.block20_act(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + + x = self.conv5(x) + x = self.bn5(x) + x = self.act5(x) + return x + + def forward_head(self, x): + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_gluon_xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception65, variant, pretrained, + feature_cfg=dict(feature_cls='hook'), + **kwargs) + + +@register_model +def gluon_xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + return _create_gluon_xception('gluon_xception65', pretrained, **kwargs) diff --git a/src/custom_timm/models/hardcorenas.py b/src/custom_timm/models/hardcorenas.py new file mode 100644 index 0000000000000000000000000000000000000000..e53134b3235feffe24fedbe451e1680cbcfed27e --- /dev/null +++ b/src/custom_timm/models/hardcorenas.py @@ -0,0 +1,151 @@ +from functools import partial + +import torch.nn as nn + +from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from .helpers import build_model_with_cfg, pretrained_cfg_for_features +from .layers import get_act_fn +from .mobilenetv3 import MobileNetV3, MobileNetV3Features +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'hardcorenas_a': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'), + 'hardcorenas_b': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'), + 'hardcorenas_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'), + 'hardcorenas_d': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'), + 'hardcorenas_e': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'), + 'hardcorenas_f': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'), +} + + +def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): + """Creates a hardcorenas model + + Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS + Paper: https://arxiv.org/abs/2102.11646 + + """ + num_features = 1280 + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=32, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=se_layer, + **kwargs, + ) + + features_only = False + model_cls = MobileNetV3 + kwargs_filter = None + if model_kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, variant, pretrained, + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) + if features_only: + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) + return model + + +@register_model +def hardcorenas_a(pretrained=False, **kwargs): + """ hardcorenas_A """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_b(pretrained=False, **kwargs): + """ hardcorenas_B """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], + ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_c(pretrained=False, **kwargs): + """ hardcorenas_C """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', + 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_d(pretrained=False, **kwargs): + """ hardcorenas_D """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'], + ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_e(pretrained=False, **kwargs): + """ hardcorenas_E """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', + 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_f(pretrained=False, **kwargs): + """ hardcorenas_F """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k3_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs) + return model diff --git a/src/custom_timm/models/helpers.py b/src/custom_timm/models/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..d68c7e6541ae5f39af0d962ff3b453e4b0c266c4 --- /dev/null +++ b/src/custom_timm/models/helpers.py @@ -0,0 +1,796 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import collections.abc +import logging +import math +import os +import re +from collections import OrderedDict, defaultdict +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Optional, Tuple, Dict, Union + +import torch +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from torch.utils.checkpoint import checkpoint + +from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .fx_features import FeatureGraphNet +from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from .layers import Conv2dSame, Linear, BatchNormAct2d +from .registry import get_pretrained_cfg + + +_logger = logging.getLogger(__name__) + + +# Global variables for rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. +_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = '' + if isinstance(checkpoint, dict): + if use_ema and checkpoint.get('state_dict_ema', None) is not None: + state_dict_key = 'state_dict_ema' + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + pretrained_loc = hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=True): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg: + # if a filename override is set, return tuple for location w/ (hub_id, filename) + pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] + return load_from, pretrained_loc + + +def set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + pretrained_cfg (dict): Default pretrained model cfg + load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: + _logger.warning("No pretrained weights exist for this model. Using random initialization.") + return + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS) + + if load_fn is not None: + load_fn(model, pretrained_loc) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(pretrained_loc) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset + num_classes (int): num_classes for model + in_chans (int): in_chans for model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + + """ + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + if isinstance(pretrained_loc, (list, tuple)): + state_dict = load_state_dict_from_hf(*pretrained_loc) + else: + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") + return + + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = pretrained_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != pretrained_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) + + +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? + for tr in to_remove: + pretrained_cfg.pop(tr, None) + return pretrained_cfg + + +def set_default_kwargs(kwargs, names, pretrained_cfg): + for n in names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + + +def filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + Args: + pretrained_cfg: input pretrained cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if pretrained_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg) + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + filter_kwargs(kwargs, names=kwargs_filter) + + +def resolve_pretrained_cfg(variant: str, pretrained_cfg=None): + if pretrained_cfg and isinstance(pretrained_cfg, dict): + # highest priority, pretrained_cfg available and passed as arg + return deepcopy(pretrained_cfg) + # fallback to looking up pretrained cfg in model registry by variant identifier + pretrained_cfg = get_pretrained_cfg(variant) + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {variant} model. Using a default." + f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = dict( + url='', + num_classes=1000, + input_size=(3, 224, 224), + pool_size=None, + crop_pct=.9, + interpolation='bicubic', + first_conv='', + classifier='', + ) + return pretrained_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + pretrained_cfg: Optional[Dict] = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[Dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + pretrained_custom_load: bool = False, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretrained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + pretrained_cfg (dict): model's pretrained weight/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg) + update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter) + pretrained_cfg.setdefault('architecture', variant) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Build the model + model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_custom_load: + # FIXME improve custom load trigger + load_custom_pretrained(model, pretrained_cfg=pretrained_cfg) + else: + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat + + return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module + + +def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False): + if module._parameters and not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules_with_params( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if module._parameters and depth_first and include_root: + yield name, module + + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects, + group_matcher: Union[Dict, Callable], + output_values: bool = False, + reverse: bool = False +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return ord, + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if output_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not output_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) + + +def group_modules( + module: nn.Module, + group_matcher, + output_values=False, + reverse=False, +): + return group_with_matcher( + named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) + + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, + preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + + +def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'): + prefix_is_tuple = isinstance(prefix, tuple) + if isinstance(module_types, str): + if module_types == 'container': + module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict) + else: + module_types = (nn.Sequential,) + for name, module in named_modules: + if depth and isinstance(module, module_types): + yield from flatten_modules( + module.named_children(), + depth - 1, + prefix=(name,) if prefix_is_tuple else name, + module_types=module_types, + ) + else: + if prefix_is_tuple: + name = prefix + (name,) + yield name, module + else: + if prefix: + name = '.'.join([prefix, name]) + yield name, module diff --git a/src/custom_timm/models/hrnet.py b/src/custom_timm/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..08405e8793f4600a40bcea0cb6d5855e1d2f34b0 --- /dev/null +++ b/src/custom_timm/models/hrnet.py @@ -0,0 +1,858 @@ +""" HRNet + +Copied from https://github.com/HRNet/HRNet-Image-Classification + +Original header: + Copyright (c) Microsoft + Licensed under the MIT License. + Written by Bin Xiao (Bin.Xiao@microsoft.com) + Modified by Ke Sun (sunk@mail.ustc.edu.cn) +""" +import logging +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureInfo +from .helpers import build_model_with_cfg, pretrained_cfg_for_features +from .layers import create_classifier +from .registry import register_model +from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE + +_BN_MOMENTUM = 0.1 +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'hrnet_w18_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'), + 'hrnet_w18_small_v2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'), + 'hrnet_w18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'), + 'hrnet_w30': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'), + 'hrnet_w32': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'), + 'hrnet_w40': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), + 'hrnet_w44': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), + 'hrnet_w48': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), + 'hrnet_w64': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), +} + +cfg_cls = dict( + hrnet_w18_small=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18_small_v2=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w30=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w32=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w40=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w44=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w48=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w64=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(64, 128), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(64, 128, 256), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(64, 128, 256, 512), + FUSE_METHOD='SUM', + ), + ) +) + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_in_chs, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_in_chs, num_channels) + + self.num_in_chs = num_in_chs + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.fuse_act = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels): + error_msg = '' + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) + elif num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) + elif num_branches != len(num_in_chs): + error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs)) + if error_msg: + _logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_in_chs[branch_index], num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] + self.num_in_chs[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return nn.Identity() + + num_branches = self.num_branches + num_in_chs = self.num_in_chs + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(nn.Identity()) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_in_chs[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_in_chs[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_in_chs(self): + return self.num_in_chs + + def forward(self, x: List[torch.Tensor]): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i, branch in enumerate(self.branches): + x[i] = branch(x[i]) + + x_fuse = [] + for i, fuse_outer in enumerate(self.fuse_layers): + y = x[0] if i == 0 else fuse_outer[0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + fuse_outer[j](x[j]) + x_fuse.append(self.fuse_act(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'): + super(HighResolutionNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + + stem_width = cfg['STEM_WIDTH'] + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) + self.act2 = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + self.head = head + self.head_channels = None # set if _make_head called + if head == 'classification': + # Classification Head + self.num_features = 2048 + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + elif head == 'incre': + self.num_features = 2048 + self.incre_modules, _, _ = self._make_head(pre_stage_channels, True) + else: + self.incre_modules = None + self.num_features = 256 + + curr_stride = 2 + # module names aren't actually valid here, hook or FeatureNet based extraction would not work + self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')] + for i, c in enumerate(self.head_channels if self.head_channels else num_channels): + curr_stride *= 2 + c = c * 4 if self.head_channels else c # head block expansion factor of 4 + self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')] + + self.init_weights() + + def _make_head(self, pre_stage_channels, incre_only=False): + head_block = Bottleneck + self.head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1)) + incre_modules = nn.ModuleList(incre_modules) + if incre_only: + return incre_modules, None, None + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = self.head_channels[i] * head_block.expansion + out_channels = self.head_channels[i + 1] * head_block.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=self.head_channels[3] * head_block.expansion, + out_channels=self.num_features, kernel_size=1, stride=1, padding=0 + ), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(inplanes, planes, stride, downsample)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + reset_multi_scale_output = multi_scale_output or i < num_modules - 1 + modules.append(HighResolutionModule( + num_branches, block, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) + ) + num_in_chs = modules[-1].get_num_in_chs() + + return nn.Sequential(*modules), num_in_chs + + @torch.jit.ignore + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv[12]|bn[12]', + blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [ + (r'^layer(\d+)\.(\d+)', None), + (r'^stage(\d+)\.(\d+)', None), + (r'^transition(\d+)', (99999,)), + ], + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "gradient checkpointing not supported" + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def stages(self, x) -> List[torch.Tensor]: + x = self.layer1(x) + + xl = [t(x) for i, t in enumerate(self.transition1)] + yl = self.stage2(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)] + yl = self.stage3(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)] + yl = self.stage4(xl) + return yl + + def forward_features(self, x): + # Stem + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + # Stages + yl = self.stages(x) + if self.incre_modules is None or self.downsamp_modules is None: + return yl + y = self.incre_modules[0](yl[0]) + for i, down in enumerate(self.downsamp_modules): + y = self.incre_modules[i + 1](yl[i + 1]) + down(y) + y = self.final_layer(y) + return y + + def forward_head(self, x, pre_logits: bool = False): + # Classification Head + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + y = self.forward_features(x) + x = self.forward_head(y) + return x + + +class HighResolutionNetFeatures(HighResolutionNet): + """HighResolutionNet feature extraction + + The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so. + It would be more complicated to use the FeatureNet helpers. + + The `feature_location=incre` allows grabbing increased channel count features using part of the + classification head. If `feature_location=''` the default HRNet features are returned. First stem + conv is used for stride 2 features. + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, + feature_location='incre', out_indices=(0, 1, 2, 3, 4)): + assert feature_location in ('incre', '') + super(HighResolutionNetFeatures, self).__init__( + cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, + drop_rate=drop_rate, head=feature_location) + self.feature_info = FeatureInfo(self.feature_info, out_indices) + self._out_idx = {i for i in out_indices} + + def forward_features(self, x): + assert False, 'Not supported' + + def forward(self, x) -> List[torch.tensor]: + out = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + if 0 in self._out_idx: + out.append(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + x = self.stages(x) + if self.incre_modules is not None: + x = [incre(f) for f, incre in zip(x, self.incre_modules)] + for i, f in enumerate(x): + if i + 1 in self._out_idx: + out.append(f) + return out + + +def _create_hrnet(variant, pretrained, **model_kwargs): + model_cls = HighResolutionNet + features_only = False + kwargs_filter = None + if model_kwargs.pop('features_only', False): + model_cls = HighResolutionNetFeatures + kwargs_filter = ('num_classes', 'global_pool') + features_only = True + model = build_model_with_cfg( + model_cls, variant, pretrained, + model_cfg=cfg_cls[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) + if features_only: + model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) + model.default_cfg = model.pretrained_cfg # backwards compat + return model + + +@register_model +def hrnet_w18_small(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w18_small', pretrained, **kwargs) + + +@register_model +def hrnet_w18_small_v2(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs) + + +@register_model +def hrnet_w18(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w18', pretrained, **kwargs) + + +@register_model +def hrnet_w30(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w30', pretrained, **kwargs) + + +@register_model +def hrnet_w32(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w32', pretrained, **kwargs) + + +@register_model +def hrnet_w40(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w40', pretrained, **kwargs) + + +@register_model +def hrnet_w44(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w44', pretrained, **kwargs) + + +@register_model +def hrnet_w48(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w48', pretrained, **kwargs) + + +@register_model +def hrnet_w64(pretrained=False, **kwargs): + return _create_hrnet('hrnet_w64', pretrained, **kwargs) diff --git a/src/custom_timm/models/hub.py b/src/custom_timm/models/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1a6e5df0279d99b2a57f0762f5214de13dad94 --- /dev/null +++ b/src/custom_timm/models/hub.py @@ -0,0 +1,170 @@ +import json +import logging +import os +from functools import partial +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Union + +import torch +from torch.hub import HASH_REGEX, download_url_to_file, urlparse + +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from custom_timm import __version__ + +try: + from huggingface_hub import (create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub.utils import EntryNotFoundError + hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + +_logger = logging.getLogger(__name__) + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def hf_split(hf_id): + # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' + hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + return hf_hub_download(hf_model_id, filename, revision=hf_revision) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + pretrained_cfg = load_cfg_from_json(cached_file) + pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation + pretrained_cfg['source'] = 'hf-hub' + model_name = pretrained_cfg.get('architecture') + return pretrained_cfg, model_name + + +def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, filename) + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict + + +def save_for_hf(model, save_directory, model_config=None): + assert has_hf_hub(True) + model_config = model_config or {} + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / 'pytorch_model.bin' + torch.save(model.state_dict(), weights_path) + + config_path = save_directory / 'config.json' + hf_config = model.pretrained_cfg + hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes) + hf_config['num_features'] = model_config.pop('num_features', model.num_features) + hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])]) + hf_config.update(model_config) + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def push_to_hf_hub( + model, + repo_id: str, + commit_message: str ='Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_config: Optional[dict] = None, +): + # Create repo if doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf(model, tmpdir, model_config=model_config) + + # Add readme if does not exist + if not has_readme: + readme_path = Path(tmpdir) / "README.md" + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}' + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) diff --git a/src/custom_timm/models/inception_resnet_v2.py b/src/custom_timm/models/inception_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..ae932786961457dd149817dd58e7d50ba2345b6c --- /dev/null +++ b/src/custom_timm/models/inception_resnet_v2.py @@ -0,0 +1,382 @@ +""" Pytorch Inception-Resnet-V2 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, flatten_modules +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionResnetV2'] + +default_cfgs = { + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 'inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + }, + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'ens_adv_inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=.001) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, no_relu=False): + super(Block8, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.relu = None if no_relu else nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if self.relu is not None: + out = self.relu(out) + return out + + +class InceptionResnetV2(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + super(InceptionResnetV2, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + assert output_stride == 32 + + self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] + + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17) + ) + self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] + + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(no_relu=True) + self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))} + module_map.pop(('classif',)) + + def _matcher(name): + if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]): + return 0 + elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]): + return 1 + elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]): + return len(module_map) + 1 + else: + for k in module_map.keys(): + if k == tuple(name.split('.')[:len(k)]): + return module_map[k] + return float('inf') + return _matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, "checkpointing not supported" + + @torch.jit.ignore + def get_classifier(self): + return self.classif + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return x if pre_logits else self.classif(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs) + + +@register_model +def inception_resnet_v2(pretrained=False, **kwargs): + r"""InceptionResnetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." ` paper. + """ + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + + +@register_model +def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): + r""" Ensemble Adversarially trained InceptionResnetV2 model architecture + As per https://arxiv.org/abs/1705.07204 and + https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. + """ + return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/src/custom_timm/models/inception_v3.py b/src/custom_timm/models/inception_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..1e03afd9af9fbd463c17a9f0c961f73026c779e2 --- /dev/null +++ b/src/custom_timm/models/inception_v3.py @@ -0,0 +1,475 @@ +""" Inception-V3 + +Originally from torchvision Inception3 model +Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules +from .registry import register_model +from .layers import trunc_normal_, create_classifier, Linear + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # original PyTorch weights, ported from Tensorflow but modified + 'inception_v3': _cfg( + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + has_aux=True), # checkpoint has aux logit layer weights + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + 'tf_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', + num_classes=1000, has_aux=False, label_offset=1), + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + 'adv_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', + num_classes=1000, has_aux=False, label_offset=1), + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + 'gluon_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', + mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults + std=IMAGENET_DEFAULT_STD, # also works well with inception defaults + has_aux=False, + ) +} + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.aux_logits = aux_logits + + self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + else: + self.AuxLogits = None + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), + dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), + dict(num_chs=288, reduction=8, module='Mixed_5d'), + dict(num_chs=768, reduction=16, module='Mixed_6e'), + dict(num_chs=2048, reduction=32, module='Mixed_7c'), + ] + + self.num_features = 2048 + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))} + module_map.pop(('fc',)) + + def _matcher(name): + if any([name.startswith(n) for n in ('Conv2d_1', 'Conv2d_2')]): + return 0 + elif any([name.startswith(n) for n in ('Conv2d_3', 'Conv2d_4')]): + return 1 + else: + for k in module_map.keys(): + if k == tuple(name.split('.')[:len(k)]): + return module_map[k] + return float('inf') + return _matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_preaux(self, x): + x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147 + x = self.Pool1(x) # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71 + x = self.Pool2(x) # N x 192 x 35 x 35 + x = self.Mixed_5b(x) # N x 256 x 35 x 35 + x = self.Mixed_5c(x) # N x 288 x 35 x 35 + x = self.Mixed_5d(x) # N x 288 x 35 x 35 + x = self.Mixed_6a(x) # N x 768 x 17 x 17 + x = self.Mixed_6b(x) # N x 768 x 17 x 17 + x = self.Mixed_6c(x) # N x 768 x 17 x 17 + x = self.Mixed_6d(x) # N x 768 x 17 x 17 + x = self.Mixed_6e(x) # N x 768 x 17 x 17 + return x + + def forward_postaux(self, x): + x = self.Mixed_7a(x) # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) # N x 2048 x 8 x 8 + return x + + def forward_features(self, x): + x = self.forward_preaux(x) + x = self.forward_postaux(x) + return x + + def forward_head(self, x): + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +class InceptionV3Aux(InceptionV3): + """InceptionV3 with AuxLogits + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True): + super(InceptionV3Aux, self).__init__( + num_classes, in_chans, drop_rate, global_pool, aux_logits) + + def forward_features(self, x): + x = self.forward_preaux(x) + aux = self.AuxLogits(x) if self.training else None + x = self.forward_postaux(x) + return x, aux + + def forward(self, x): + x, aux = self.forward_features(x) + x = self.forward_head(x) + return x, aux + + +def _create_inception_v3(variant, pretrained=False, **kwargs): + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + assert not kwargs.pop('features_only', False) + model_cls = InceptionV3Aux + load_strict = pretrained_cfg['has_aux'] + else: + model_cls = InceptionV3 + load_strict = not pretrained_cfg['has_aux'] + + return build_model_with_cfg( + model_cls, variant, pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_strict=load_strict, + **kwargs) + + +@register_model +def inception_v3(pretrained=False, **kwargs): + # original PyTorch weights, ported from Tensorflow but modified + model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def adv_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def gluon_inception_v3(pretrained=False, **kwargs): + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) + return model diff --git a/src/custom_timm/models/inception_v4.py b/src/custom_timm/models/inception_v4.py new file mode 100644 index 0000000000000000000000000000000000000000..02d7128221c521c245d3c8832923392c43255180 --- /dev/null +++ b/src/custom_timm/models/inception_v4.py @@ -0,0 +1,330 @@ +""" Pytorch Inception-V4 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionV4'] + +default_cfgs = { + 'inception_v4': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'features.0.conv', 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed3a(nn.Module): + def __init__(self): + super(Mixed3a, self).__init__() + self.maxpool = nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed4a(nn.Module): + def __init__(self): + super(Mixed4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(64, 96, kernel_size=(3, 3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed5a(nn.Module): + def __init__(self): + super(Mixed5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class InceptionA(nn.Module): + def __init__(self): + super(InceptionA, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class ReductionA(nn.Module): + def __init__(self): + super(ReductionA, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionB(nn.Module): + def __init__(self): + super(InceptionB, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class ReductionB(nn.Module): + def __init__(self): + super(ReductionB, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionC(nn.Module): + def __init__(self): + super(InceptionC, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): + super(InceptionV4, self).__init__() + assert output_stride == 32 + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + + self.features = nn.Sequential( + BasicConv2d(in_chans, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed3a(), + Mixed4a(), + Mixed5a(), + InceptionA(), + InceptionA(), + InceptionA(), + InceptionA(), + ReductionA(), # Mixed6a + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + ReductionB(), # Mixed7a + InceptionC(), + InceptionC(), + InceptionC(), + ) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='features.2'), + dict(num_chs=160, reduction=4, module='features.3'), + dict(num_chs=384, reduction=8, module='features.9'), + dict(num_chs=1024, reduction=16, module='features.17'), + dict(num_chs=1536, reduction=32, module='features.21'), + ] + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^features\.[012]\.', + blocks=r'^features\.(\d+)' + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return x if pre_logits else self.last_linear(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_inception_v4(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionV4, variant, pretrained, + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def inception_v4(pretrained=False, **kwargs): + return _create_inception_v4('inception_v4', pretrained, **kwargs) diff --git a/src/custom_timm/models/levit.py b/src/custom_timm/models/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..3f8a360681a6d7381eb28d1ec716bb061fb7e5e5 --- /dev/null +++ b/src/custom_timm/models/levit.py @@ -0,0 +1,592 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" + +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools +from copy import deepcopy +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn + +from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, checkpoint_seq +from .layers import to_ntuple, get_act_layer +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), + **kwargs + } + + +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), + + levit_256d=_cfg(url='', classifier='head.l'), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), + + levit_256d=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6)), +) + +__all__ = ['Levit'] + + +@register_model +def levit_128s(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_192(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_256(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_384(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_256d(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_256d', pretrained=pretrained, use_conv=use_conv, distilled=False, **kwargs) + + +class ConvNorm(nn.Sequential): + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(nn.Sequential): + def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', nn.Linear(in_features, out_features, bias=False)) + self.add_module('bn', nn.BatchNorm1d(out_features)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(nn.Sequential): + def __init__(self, in_features, out_features, bias=True, std=0.02): + super().__init__() + self.add_module('bn', nn.BatchNorm1d(in_features)) + self.add_module('l', nn.Linear(in_features, out_features, bias=bias)) + + trunc_normal_(self.l.weight, std=std) + if self.l.bias is not None: + nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): + super().__init__() + ln_layer = ConvNorm if use_conv else LinearNorm + self.use_conv = use_conv + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = int(attn_ratio * key_dim) * num_heads + + self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution) + ) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2)) + pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos) + self.ab = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view( + B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + else: + B, N, C = x.shape + q, k, v = self.qkv(x).view( + B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 3, 1) + v = v.permute(0, 2, 1, 3) + + attn = q @ k * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim) + x = self.proj(x) + return x + + +class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False): + super().__init__() + self.stride = stride + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = self.val_dim * self.num_heads + self.resolution = resolution + self.resolution_out_area = resolution_out ** 2 + + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) + + self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out) + ) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out) + ) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2)) + k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + q_pos = torch.stack(torch.meshgrid( + torch.arange(0, resolution, step=stride), + torch.arange(0, resolution, step=stride))).flatten(1) + rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos) + + self.ab = {} # per-device attention_biases cache + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution, self.resolution) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3) + k = k.permute(0, 2, 3, 1) # BHCN + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + attn = q @ k * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim) + x = self.proj(x) + return x + + +class Levit(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + + NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems + w/ train scripts that don't take tuple outputs, + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, + hybrid_backbone=None, + down_ops=None, + act_layer='hard_swish', + attn_act_layer='hard_swish', + use_conv=False, + global_pool='avg', + drop_rate=0., + drop_path_rate=0.): + super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) + ln_layer = ConvNorm if use_conv else LinearNorm + self.use_conv = use_conv + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.grad_checkpointing = False + + num_stages = len(embed_dim) + assert len(depth) == len(num_heads) == num_stages + key_dim = to_ntuple(num_stages)(key_dim) + attn_ratio = to_ntuple(num_stages)(attn_ratio) + mlp_ratio = to_ntuple(num_stages)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) + + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) + + self.blocks = [] + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), + drop_path_rate)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path_rate)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_out = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_out=resolution_out, use_conv=use_conv)) + resolution = resolution_out + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path_rate)) + self.blocks = nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None, distillation=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +class LevitDistilled(Levit): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity() + self.distilled_training = False # must set this True to train w/ distillation token + + @torch.jit.ignore + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=None, distillation=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + + def forward_head(self, x): + if self.global_pool == 'avg': + x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1) + x, x_dist = self.head(x), self.head_dist(x) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode + return x, x_dist + else: + # during standard train/finetune, inference average the classifier predictions + return (x + x_dist) / 2 + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, distilled=True, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + LevitDistilled if distilled else Levit, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) + return model + diff --git a/src/custom_timm/optim/__init__.py b/src/custom_timm/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee4958eb562bcfe06a5da72be4b76ee610a0ccc --- /dev/null +++ b/src/custom_timm/optim/__init__.py @@ -0,0 +1,15 @@ +from .adabelief import AdaBelief +from .adafactor import Adafactor +from .adahessian import Adahessian +from .adamp import AdamP +from .adamw import AdamW +from .lamb import Lamb +from .lars import Lars +from .lookahead import Lookahead +from .madgrad import MADGRAD +from .nadam import Nadam +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs diff --git a/src/custom_timm/optim/adabelief.py b/src/custom_timm/optim/adabelief.py new file mode 100644 index 0000000000000000000000000000000000000000..951d715cc0b605df2f7313c95840b7784c4d0a70 --- /dev/null +++ b/src/custom_timm/optim/adabelief.py @@ -0,0 +1,201 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdaBelief(Optimizer): + r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-16) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + decoupled_decay (boolean, optional): (default: True) If set as True, then + the optimizer uses decoupled weight decay as in AdamW + fixed_decay (boolean, optional): (default: False) This is used when weight_decouple + is set as True. + When fixed_decay == True, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay$. + When fixed_decay == False, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the + weight decay ratio decreases with learning rate (lr). + rectify (boolean, optional): (default: True) If set as True, then perform the rectified + update similar to RAdam + degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update + when variance of gradient is high + reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020 + + For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer' + For example train/args for EfficientNet see these gists + - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037 + - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3 + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False, + decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, + degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify, + fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)]) + super(AdaBelief, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdaBelief, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + amsgrad = group['amsgrad'] + + # State initialization + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p_fp32) + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p_fp32) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p_fp32) + + # perform weight decay, check if decoupled weight decay + if group['decoupled_decay']: + if not group['fixed_decay']: + p_fp32.mul_(1.0 - group['lr'] * group['weight_decay']) + else: + p_fp32.mul_(1.0 - group['weight_decay']) + else: + if group['weight_decay'] != 0: + grad.add_(p_fp32, alpha=group['weight_decay']) + + # get current state variable + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Update first and second moment running average + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + grad_residual = grad - exp_avg + exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2) + + if amsgrad: + max_exp_avg_var = state['max_exp_avg_var'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + # update + if not group['rectify']: + # Default update + step_size = group['lr'] / bias_correction1 + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + # Rectified update, forked from RAdam + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + elif group['degenerated_to_sgd']: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + if num_sma >= 5: + denom = exp_avg_var.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + elif step_size > 0: + p_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss diff --git a/src/custom_timm/optim/adafactor.py b/src/custom_timm/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..06057433a9bffa555bdc13b27a1c56cff26acf15 --- /dev/null +++ b/src/custom_timm/optim/adafactor.py @@ -0,0 +1,167 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = not lr + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + + factored, use_first_moment = self._get_options(group, grad.shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) + update = exp_avg + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t) + + p_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss diff --git a/src/custom_timm/optim/adahessian.py b/src/custom_timm/optim/adahessian.py new file mode 100644 index 0000000000000000000000000000000000000000..985c67ca686a65f61f5c5b1a7db3e5bba815a19b --- /dev/null +++ b/src/custom_timm/optim/adahessian.py @@ -0,0 +1,156 @@ +""" AdaHessian Optimizer + +Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py +Originally licensed MIT, Copyright 2020, David Samuel +""" +import torch + + +class Adahessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate (default: 0.1) + betas ((float, float), optional): coefficients used for computing running averages of gradient and the + squared hessian trace (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional): exponent of the hessian trace (default: 1.0) + update_each (int, optional): compute the hessian trace approximation only after *this* number of steps + (to save time) (default: 1) + n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.avg_conv_kernel = avg_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.seed = 2147483647 + self.generator = torch.Generator().manual_seed(self.seed) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(Adahessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + @property + def is_second_order(self): + return True + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. + """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(self.seed) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + # Rademacher distribution {-1.0, 1.0} + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] + h_zs = torch.autograd.grad( + grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.avg_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of Hessian diagonal square values + state['exp_hessian_diag_sq'] = torch.zeros_like(p) + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/src/custom_timm/optim/adamp.py b/src/custom_timm/optim/adamp.py new file mode 100644 index 0000000000000000000000000000000000000000..ee187633ab745dbb0344dcdc3dcb1cf40e6ae5e9 --- /dev/null +++ b/src/custom_timm/optim/adamp.py @@ -0,0 +1,105 @@ +""" +AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer +import math + + +def _channel_view(x) -> torch.Tensor: + return x.reshape(x.size(0), -1) + + +def _layer_view(x) -> torch.Tensor: + return x.reshape(1, -1) + + +def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): + wd = 1. + expand_size = (-1,) + (1,) * (len(p.shape) - 1) + for view_func in [_channel_view, _layer_view]: + param_view = view_func(p) + grad_view = view_func(grad) + cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() + + # FIXME this is a problem for PyTorch XLA + if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): + p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) + wd = wd_ratio + return perturb, wd + + return perturb, wd + + +class AdamP(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + super(AdamP, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + beta1, beta2 = group['betas'] + nesterov = group['nesterov'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + # Adam + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + step_size = group['lr'] / bias_correction1 + + if nesterov: + perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom + else: + perturb = exp_avg / denom + + # Projection + wd_ratio = 1. + if len(p.shape) > 1: + perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if group['weight_decay'] > 0: + p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) + + # Step + p.add_(perturb, alpha=-step_size) + + return loss diff --git a/src/custom_timm/optim/adamw.py b/src/custom_timm/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..66478bc6ef3c50ab9d40cabb0cfb2bd24277c815 --- /dev/null +++ b/src/custom_timm/optim/adamw.py @@ -0,0 +1,122 @@ +""" AdamW Optimizer +Impl copied from PyTorch master + +NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed +someday +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/src/custom_timm/optim/lamb.py b/src/custom_timm/optim/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..12c7c49b8a01ef793c97654ac938259ca6508449 --- /dev/null +++ b/src/custom_timm/optim/lamb.py @@ -0,0 +1,192 @@ +""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): + defaults = dict( + lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, + grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, + trust_clip=trust_clip, always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group['step'] + bias_correction2 = 1 - beta2 ** group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/src/custom_timm/optim/lars.py b/src/custom_timm/optim/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..38ca9e0b5cb90855104ce7b5ff358cb7fa343f12 --- /dev/null +++ b/src/custom_timm/optim/lars.py @@ -0,0 +1,135 @@ +""" PyTorch LARS / LARC Optimizer + +An implementation of LARS (SGD) + LARC in PyTorch + +Based on: + * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Additional cleanup and modifications to properly support PyTorch XLA. + +Copyright 2021 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer + + +class Lars(Optimizer): + """ LARS for PyTorch + + Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate (default: 1.0). + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) + eps (float): eps for division denominator (default: 1e-8) + trust_clip (bool): enable LARC trust ratio clipping (default: False) + always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) + """ + + def __init__( + self, + params, + lr=1.0, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coeff=0.001, + eps=1e-8, + trust_clip=False, + always_adapt=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coeff=trust_coeff, + eps=eps, + trust_clip=trust_clip, + always_adapt=always_adapt, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + trust_coeff = group['trust_coeff'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + # apply LARS LR adaptation, LARC clipping, weight decay + # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + if weight_decay != 0 or group['always_adapt']: + w_norm = p.norm(2.0) + g_norm = grad.norm(2.0) + trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, trust_ratio, one_tensor), + one_tensor, + ) + if group['trust_clip']: + trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + grad.add_(p, alpha=weight_decay) + grad.mul_(trust_ratio) + + # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + p.add_(grad, alpha=-group['lr']) + + return loss \ No newline at end of file diff --git a/src/custom_timm/optim/lookahead.py b/src/custom_timm/optim/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..462c3acd247016a94acd39a27dd44f29ae854d31 --- /dev/null +++ b/src/custom_timm/optim/lookahead.py @@ -0,0 +1,61 @@ +""" Lookahead Optimizer Wrapper. +Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + # NOTE super().__init__() not called on purpose + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) + self._base_optimizer = base_optimizer + self.param_groups = base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.defaults.update(defaults) + self.state = defaultdict(dict) + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self._base_optimizer.param_groups: + group.setdefault(name, default) + + @torch.no_grad() + def update_slow(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self._base_optimizer.state[fast_p] + if 'lookahead_slow_buff' not in param_state: + param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) + param_state['lookahead_slow_buff'].copy_(fast_p) + slow = param_state['lookahead_slow_buff'] + slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) + fast_p.copy_(slow) + + def sync_lookahead(self): + for group in self._base_optimizer.param_groups: + self.update_slow(group) + + @torch.no_grad() + def step(self, closure=None): + loss = self._base_optimizer.step(closure) + for group in self._base_optimizer.param_groups: + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) + return loss + + def state_dict(self): + return self._base_optimizer.state_dict() + + def load_state_dict(self, state_dict): + self._base_optimizer.load_state_dict(state_dict) + self.param_groups = self._base_optimizer.param_groups diff --git a/src/custom_timm/optim/madgrad.py b/src/custom_timm/optim/madgrad.py new file mode 100644 index 0000000000000000000000000000000000000000..a76713bf27ed1daf0ce598ac5f25c6238c7fdb57 --- /dev/null +++ b/src/custom_timm/optim/madgrad.py @@ -0,0 +1,184 @@ +""" PyTorch MADGRAD optimizer + +MADGRAD: https://arxiv.org/abs/2101.11075 + +Code from: https://github.com/facebookresearch/madgrad +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.optim + +if TYPE_CHECKING: + from torch.optim.optimizer import _params_t +else: + _params_t = Any + + +class MADGRAD(torch.optim.Optimizer): + """ + MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic + Optimization. + + .. _MADGRAD: https://arxiv.org/abs/2101.11075 + + MADGRAD is a general purpose optimizer that can be used in place of SGD or + Adam may converge faster and generalize better. Currently GPU-only. + Typically, the same learning rate schedule that is used for SGD or Adam may + be used. The overall learning rate is not comparable to either method and + should be determined by a hyper-parameter sweep. + + MADGRAD requires less weight decay than other methods, often as little as + zero. Momentum values used for SGD or Adam's beta1 should work here also. + + On sparse problems both weight_decay and momentum should be set to 0. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate (default: 1e-2). + momentum (float): + Momentum value in the range [0,1) (default: 0.9). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6). + """ + + def __init__( + self, + params: _params_t, + lr: float = 1e-2, + momentum: float = 0.9, + weight_decay: float = 0, + eps: float = 1e-6, + decoupled_decay: bool = False, + ): + if momentum < 0 or momentum >= 1: + raise ValueError(f"Momentum {momentum} must be in the range [0,1]") + if lr <= 0: + raise ValueError(f"Learning rate {lr} must be positive") + if weight_decay < 0: + raise ValueError(f"Weight decay {weight_decay} must be non-negative") + if eps < 0: + raise ValueError(f"Eps must be non-negative") + + defaults = dict( + lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self) -> bool: + return False + + @property + def supports_flat_params(self) -> bool: + return True + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + eps = group['eps'] + lr = group['lr'] + eps + weight_decay = group['weight_decay'] + momentum = group['momentum'] + ck = 1 - momentum + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if momentum != 0.0 and grad.is_sparse: + raise RuntimeError("momentum != 0 is not compatible with sparse gradients") + + state = self.state[p] + if len(state) == 0: + state['step'] = 0 + state['grad_sum_sq'] = torch.zeros_like(p) + state['s'] = torch.zeros_like(p) + if momentum != 0: + state['x0'] = torch.clone(p).detach() + + state['step'] += 1 + grad_sum_sq = state['grad_sum_sq'] + s = state['s'] + lamb = lr * math.sqrt(state['step']) + + # Apply weight decay + if weight_decay != 0: + if group['decoupled_decay']: + p.mul_(1.0 - group['lr'] * weight_decay) + else: + if grad.is_sparse: + raise RuntimeError("weight_decay option is not compatible with sparse gradients") + grad.add_(p, alpha=weight_decay) + + if grad.is_sparse: + grad = grad.coalesce() + grad_val = grad._values() + + p_masked = p.sparse_mask(grad) + grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad) + s_masked = s.sparse_mask(grad) + + # Compute x_0 from other known quantities + rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps) + x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1) + + # Dense + sparse op + grad_sq = grad * grad + grad_sum_sq.add_(grad_sq, alpha=lamb) + grad_sum_sq_masked.add_(grad_sq, alpha=lamb) + + rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps) + + s.add_(grad, alpha=lamb) + s_masked._values().add_(grad_val, alpha=lamb) + + # update masked copy of p + p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) + # Copy updated masked p to dense p using an add operation + p_masked._values().add_(p_kp1_masked_vals, alpha=-1) + p.add_(p_masked, alpha=-1) + else: + if momentum == 0: + # Compute x_0 from other known quantities + rms = grad_sum_sq.pow(1 / 3).add_(eps) + x0 = p.addcdiv(s, rms, value=1) + else: + x0 = state['x0'] + + # Accumulate second moments + grad_sum_sq.addcmul_(grad, grad, value=lamb) + rms = grad_sum_sq.pow(1 / 3).add_(eps) + + # Update s + s.add_(grad, alpha=lamb) + + # Step + if momentum == 0: + p.copy_(x0.addcdiv(s, rms, value=-1)) + else: + z = x0.addcdiv(s, rms, value=-1) + + # p is a moving average of z + p.mul_(1 - ck).add_(z, alpha=ck) + + return loss diff --git a/src/custom_timm/optim/nadam.py b/src/custom_timm/optim/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..6268d5d451ed2fe26b47e46476dc1feee7da9649 --- /dev/null +++ b/src/custom_timm/optim/nadam.py @@ -0,0 +1,92 @@ +import math + +import torch +from torch.optim.optimizer import Optimizer + + +class Nadam(Optimizer): + """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). + + It has been proposed in `Incorporating Nesterov Momentum into Adam`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + schedule_decay (float, optional): momentum schedule decay (default: 4e-3) + + __ http://cs229.stanford.edu/proj2015/054_report.pdf + __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf + + Originally taken from: https://github.com/pytorch/pytorch/pull/1408 + NOTE: Has potential issues but does work well on some problems. + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, schedule_decay=4e-3): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay) + super(Nadam, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['m_schedule'] = 1. + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + # Warming momentum schedule + m_schedule = state['m_schedule'] + schedule_decay = group['schedule_decay'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + eps = group['eps'] + state['step'] += 1 + t = state['step'] + bias_correction2 = 1 - beta2 ** t + + if group['weight_decay'] != 0: + grad = grad.add(p, alpha=group['weight_decay']) + + momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay))) + momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) + m_schedule_new = m_schedule * momentum_cache_t + m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 + state['m_schedule'] = m_schedule_new + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new)) + p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next)) + + return loss diff --git a/src/custom_timm/optim/nvnovograd.py b/src/custom_timm/optim/nvnovograd.py new file mode 100644 index 0000000000000000000000000000000000000000..fda3f4a620fcca5593034dfb9683f2c8f3b78ac1 --- /dev/null +++ b/src/custom_timm/optim/nvnovograd.py @@ -0,0 +1,120 @@ +""" Nvidia NovoGrad Optimizer. +Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(p, alpha=group['weight_decay']) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.add_(exp_avg, alpha=-group['lr']) + + return loss diff --git a/src/custom_timm/optim/optim_factory.py b/src/custom_timm/optim/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4acaec67bb094c870b5ecd34b41b14a172de8bdd --- /dev/null +++ b/src/custom_timm/optim/optim_factory.py @@ -0,0 +1,340 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +from itertools import islice +from typing import Optional, Callable, Tuple + +import torch +import torch.nn as nn +import torch.optim as optim + +from custom_timm.models.helpers import group_parameters + +from .adabelief import AdaBelief +from .adafactor import Adafactor +from .adahessian import Adahessian +from .adamp import AdamP +from .lamb import Lamb +from .lars import Lars +from .lookahead import Lookahead +from .madgrad import MADGRAD +from .nadam import Nadam +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP + +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + has_apex = False + +_logger = logging.getLogger(__name__) + + +def param_groups_weight_decay( + model: nn.Module, + weight_decay=1e-5, + no_weight_decay_list=() +): + no_weight_decay_list = set(no_weight_decay_list) + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + no_decay.append(param) + else: + decay.append(param) + + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def _group(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def _layer_map(model, layers_per_group=12, num_groups=None): + def _in_head(n, hp): + if not hp: + return True + elif isinstance(hp, (tuple, list)): + return any([n.startswith(hpi) for hpi in hp]) + else: + return n.startswith(hp) + + head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) + names_trunk = [] + names_head = [] + for n, _ in model.named_parameters(): + names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) + + # group non-head layers + num_trunk_layers = len(names_trunk) + if num_groups is not None: + layers_per_group = -(num_trunk_layers // -num_groups) + names_trunk = list(_group(names_trunk, layers_per_group)) + + num_trunk_groups = len(names_trunk) + layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} + layer_map.update({n: num_trunk_groups for n in names_head}) + return layer_map + + +def param_groups_layer_decay( + model: nn.Module, + weight_decay: float = 0.05, + no_weight_decay_list: Tuple[str] = (), + layer_decay: float = .75, + end_layer_decay: Optional[float] = None, + verbose: bool = False, +): + """ + Parameter groups for layer-wise lr decay & weight decay + Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + no_weight_decay_list = set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging + param_groups = {} + + if hasattr(model, 'group_matcher'): + # FIXME interface needs more work + layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) + else: + # fallback + layer_map = _layer_map(model) + num_layers = max(layer_map.values()) + 1 + layer_max = num_layers - 1 + layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if param.ndim == 1 or name in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = layer_map.get(name, layer_max) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_groups: + this_scale = layer_scales[layer_id] + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "param_names": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["param_names"].append(name) + param_groups[group_name]["params"].append(param) + + if verbose: + import json + _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def optimizer_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. + """ + kwargs = dict( + opt=cfg.opt, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + momentum=cfg.momentum) + if getattr(cfg, 'opt_eps', None) is not None: + kwargs['eps'] = cfg.opt_eps + if getattr(cfg, 'opt_betas', None) is not None: + kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'layer_decay', None) is not None: + kwargs['layer_decay'] = cfg.layer_decay + if getattr(cfg, 'opt_args', None) is not None: + kwargs.update(cfg.opt_args) + return kwargs + + +def create_optimizer(args, model, filter_bias_and_bn=True): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. + """ + return create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + filter_bias_and_bn=filter_bias_and_bn, + ) + + +def create_optimizer_v2( + model_or_params, + opt: str = 'sgd', + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + filter_bias_and_bn: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable] = None, + **kwargs): + """ Create an optimizer. + + TODO currently the model is passed in and all parameters are selected for optimization. + For more general use an interface that allows selection of parameters to optimize and lr groups, one of: + * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion + * expose the parameters interface and leave it up to caller + + Args: + model_or_params (nn.Module): model containing parameters to optimize + opt: name of optimizer to create + lr: initial learning rate + weight_decay: weight decay to apply in optimizer + momentum: momentum for momentum based optimizers (others may use betas via kwargs) + filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay + **kwargs: extra optimizer specific kwargs to pass through + + Returns: + Optimizer + """ + if isinstance(model_or_params, nn.Module): + # a model was passed in, extract parameters and add weight decays to appropriate layers + no_weight_decay = {} + if hasattr(model_or_params, 'no_weight_decay'): + no_weight_decay = model_or_params.no_weight_decay() + + if param_group_fn: + parameters = param_group_fn(model_or_params) + elif layer_decay is not None: + parameters = param_groups_layer_decay( + model_or_params, + weight_decay=weight_decay, + layer_decay=layer_decay, + no_weight_decay_list=no_weight_decay) + weight_decay = 0. + elif weight_decay and filter_bias_and_bn: + parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) + weight_decay = 0. + else: + parameters = model_or_params.parameters() + else: + # iterable of parameters or param groups passed in + parameters = model_or_params + + opt_lower = opt.lower() + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + + opt_args = dict(weight_decay=weight_decay, **kwargs) + if lr is not None: + opt_args.setdefault('lr', lr) + + # basic SGD & related + if opt_lower == 'sgd' or opt_lower == 'nesterov': + # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'momentum': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) + elif opt_lower == 'sgdp': + optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) + + # adaptive + elif opt_lower == 'adam': + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == 'adamp': + optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) + elif opt_lower == 'nadam': + try: + # NOTE PyTorch >= 1.10 should have native NAdam + optimizer = optim.Nadam(parameters, **opt_args) + except AttributeError: + optimizer = Nadam(parameters, **opt_args) + elif opt_lower == 'radam': + optimizer = RAdam(parameters, **opt_args) + elif opt_lower == 'adamax': + optimizer = optim.Adamax(parameters, **opt_args) + elif opt_lower == 'adabelief': + optimizer = AdaBelief(parameters, rectify=False, **opt_args) + elif opt_lower == 'radabelief': + optimizer = AdaBelief(parameters, rectify=True, **opt_args) + elif opt_lower == 'adadelta': + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == 'adagrad': + opt_args.setdefault('eps', 1e-8) + optimizer = optim.Adagrad(parameters, **opt_args) + elif opt_lower == 'adafactor': + optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == 'lamb': + optimizer = Lamb(parameters, **opt_args) + elif opt_lower == 'lambc': + optimizer = Lamb(parameters, trust_clip=True, **opt_args) + elif opt_lower == 'larc': + optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args) + elif opt_lower == 'lars': + optimizer = Lars(parameters, momentum=momentum, **opt_args) + elif opt_lower == 'nlarc': + optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args) + elif opt_lower == 'nlars': + optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'madgrad': + optimizer = MADGRAD(parameters, momentum=momentum, **opt_args) + elif opt_lower == 'madgradw': + optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args) + elif opt_lower == 'novograd' or opt_lower == 'nvnovograd': + optimizer = NvNovoGrad(parameters, **opt_args) + elif opt_lower == 'rmsprop': + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == 'rmsproptf': + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) + + # second order + elif opt_lower == 'adahessian': + optimizer = Adahessian(parameters, **opt_args) + + # NVIDIA fused optimizers, require APEX to be installed + elif opt_lower == 'fusedsgd': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'fusedmomentum': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) + elif opt_lower == 'fusedadam': + optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) + elif opt_lower == 'fusedadamw': + optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) + elif opt_lower == 'fusedlamb': + optimizer = FusedLAMB(parameters, **opt_args) + elif opt_lower == 'fusednovograd': + opt_args.setdefault('betas', (0.95, 0.98)) + optimizer = FusedNovoGrad(parameters, **opt_args) + + else: + assert False and "Invalid optimizer" + raise ValueError + + if len(opt_split) > 1: + if opt_split[0] == 'lookahead': + optimizer = Lookahead(optimizer) + + return optimizer diff --git a/src/custom_timm/optim/radam.py b/src/custom_timm/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8d22e06c42e487c831297008851b4adc254d78 --- /dev/null +++ b/src/custom_timm/optim/radam.py @@ -0,0 +1,89 @@ +"""RAdam Optimizer. +Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_fp32 = p.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) + + # more conservative since it's an approximated value + if num_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + p_fp32.add_(exp_avg, alpha=-step_size) + + p.copy_(p_fp32) + + return loss diff --git a/src/custom_timm/optim/rmsprop_tf.py b/src/custom_timm/optim/rmsprop_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..0817887db380261dfee3fcd4bd155b5d923f5248 --- /dev/null +++ b/src/custom_timm/optim/rmsprop_tf.py @@ -0,0 +1,139 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2021 Ross Wightman +""" + +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict( + lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p) + + square_avg = state['square_avg'] + one_minus_alpha = 1. - group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if group['decoupled_decay']: + p.mul_(1. - group['lr'] * group['weight_decay']) + else: + grad = grad.add(p, alpha=group['weight_decay']) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) + # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt + # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) + p.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.add_(buf, alpha=-group['lr']) + else: + p.addcdiv_(grad, avg, value=-group['lr']) + + return loss diff --git a/src/custom_timm/optim/sgdp.py b/src/custom_timm/optim/sgdp.py new file mode 100644 index 0000000000000000000000000000000000000000..baf05fa55c632371498ec53ff679b11023429df6 --- /dev/null +++ b/src/custom_timm/optim/sgdp.py @@ -0,0 +1,70 @@ +""" +SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer, required +import math + +from .adamp import projection + + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict( + lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like(p) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1. + if len(p.shape) > 1: + d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + + # Step + p.add_(d_p, alpha=-group['lr']) + + return loss diff --git a/src/custom_timm/scheduler/__init__.py b/src/custom_timm/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1961b88fc3c37cdd8c73f9fddd4bfa1ada95f23 --- /dev/null +++ b/src/custom_timm/scheduler/__init__.py @@ -0,0 +1,8 @@ +from .cosine_lr import CosineLRScheduler +from .multistep_lr import MultiStepLRScheduler +from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler + +from .scheduler_factory import create_scheduler diff --git a/src/custom_timm/scheduler/cosine_lr.py b/src/custom_timm/scheduler/cosine_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..84ee349ec281f89e331be3643b613e158bb3c194 --- /dev/null +++ b/src/custom_timm/scheduler/cosine_lr.py @@ -0,0 +1,119 @@ +""" Cosine Scheduler + +Cosine LR schedule with warmup, cycle/restarts, noise, k-decay. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine decay with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + + Inspiration from + https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/src/custom_timm/scheduler/multistep_lr.py b/src/custom_timm/scheduler/multistep_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0ca920307fa4ee6e63340d76ca278b729091e3 --- /dev/null +++ b/src/custom_timm/scheduler/multistep_lr.py @@ -0,0 +1,65 @@ +""" MultiStep LR Scheduler + +Basic multi step LR schedule with warmup, noise. +""" +import torch +import bisect +from custom_timm.scheduler.scheduler import Scheduler +from typing import List + +class MultiStepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: List[int], + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def get_curr_decay_steps(self, t): + # find where in the array t goes, + # assumes self.decay_t is sorted + return bisect.bisect_right(self.decay_t, t+1) + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/src/custom_timm/scheduler/plateau_lr.py b/src/custom_timm/scheduler/plateau_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..cacfab3ce7f073c9a99037ed85259fa3286f51ad --- /dev/null +++ b/src/custom_timm/scheduler/plateau_lr.py @@ -0,0 +1,103 @@ +""" Plateau Scheduler + +Adapts PyTorch plateau scheduler and allows application of noise, warmup. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__(self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): + super().__init__( + optimizer, + 'lr', + noise_range_t=noise_range_t, + noise_type=noise_type, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + patience=patience_t, + factor=decay_rate, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_t, + mode=mode, + min_lr=lr_min + ) + + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + self.restore_lr = None + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) + else: + if self.restore_lr is not None: + # restore actual LR from before our last noise perturbation before stepping base + for i, param_group in enumerate(self.optimizer.param_groups): + param_group['lr'] = self.restore_lr[i] + self.restore_lr = None + + self.lr_scheduler.step(metric, epoch) # step the base scheduler + + if self._is_apply_noise(epoch): + self._apply_noise(epoch) + + def _apply_noise(self, epoch): + noise = self._calculate_noise(epoch) + + # apply the noise on top of previous LR, cache the old value so we can restore for normal + # stepping of base scheduler + restore_lr = [] + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + restore_lr.append(old_lr) + new_lr = old_lr + old_lr * noise + param_group['lr'] = new_lr + self.restore_lr = restore_lr diff --git a/src/custom_timm/scheduler/poly_lr.py b/src/custom_timm/scheduler/poly_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..9c351be6ed56f8fe130cd391df0a7a7f89c7a96c --- /dev/null +++ b/src/custom_timm/scheduler/poly_lr.py @@ -0,0 +1,116 @@ +""" Polynomial Scheduler + +Polynomial LR schedule with warmup, noise. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging + +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class PolyLRScheduler(Scheduler): + """ Polynomial LR Scheduler w/ warmup, noise, and k-decay + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + power: float = 0.5, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.power = power + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/src/custom_timm/scheduler/scheduler.py b/src/custom_timm/scheduler/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..af20be9b59d2fecfd813785ea6bc06093f57858d --- /dev/null +++ b/src/custom_timm/scheduler/scheduler.py @@ -0,0 +1,117 @@ +from typing import Dict, Any + +import torch + + +class Scheduler: + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. + + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? + self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + def get_epoch_values(self, epoch: int): + return None + + def get_update_values(self, num_updates: int): + return None + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self.get_epoch_values(epoch) + if values is not None: + values = self._add_noise(values, epoch) + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self.get_update_values(num_updates) + if values is not None: + values = self._add_noise(values, num_updates) + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + if 'lr_scale' in param_group: + param_group[self.param_group_field] = value * param_group['lr_scale'] + else: + param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self._is_apply_noise(t): + noise = self._calculate_noise(t) + lrs = [v + v * noise for v in lrs] + return lrs + + def _is_apply_noise(self, t) -> bool: + """Return True if scheduler in noise range.""" + apply_noise = False + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + return apply_noise + + def _calculate_noise(self, t) -> float: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + return noise + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + return noise diff --git a/src/custom_timm/scheduler/scheduler_factory.py b/src/custom_timm/scheduler/scheduler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..3e100fe029c3bc2405d3cae0695376603dd78618 --- /dev/null +++ b/src/custom_timm/scheduler/scheduler_factory.py @@ -0,0 +1,107 @@ +""" Scheduler Factory +Hacked together by / Copyright 2021 Ross Wightman +""" +from .cosine_lr import CosineLRScheduler +from .multistep_lr import MultiStepLRScheduler +from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler + + +def create_scheduler(args, optimizer): + num_epochs = args.epochs + + if getattr(args, 'lr_noise', None) is not None: + lr_noise = getattr(args, 'lr_noise') + if isinstance(lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = lr_noise * num_epochs + else: + noise_range = None + noise_args = dict( + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + cycle_args = dict( + cycle_mul=getattr(args, 'lr_cycle_mul', 1.), + cycle_decay=getattr(args, 'lr_cycle_decay', 0.1), + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + ) + + lr_scheduler = None + if args.sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + k_decay=getattr(args, 'lr_k_decay', 1.0), + **cycle_args, + **noise_args, + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + t_in_epochs=True, + **cycle_args, + **noise_args, + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + **noise_args, + ) + elif args.sched == 'multistep': + lr_scheduler = MultiStepLRScheduler( + optimizer, + decay_t=args.decay_milestones, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + **noise_args, + ) + elif args.sched == 'plateau': + mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=args.decay_rate, + patience_t=args.patience_epochs, + lr_min=args.min_lr, + mode=mode, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cooldown_t=0, + **noise_args, + ) + elif args.sched == 'poly': + lr_scheduler = PolyLRScheduler( + optimizer, + power=args.decay_rate, # overloading 'decay_rate' as polynomial power + t_initial=num_epochs, + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + k_decay=getattr(args, 'lr_k_decay', 1.0), + **cycle_args, + **noise_args, + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + + return lr_scheduler, num_epochs diff --git a/src/custom_timm/scheduler/step_lr.py b/src/custom_timm/scheduler/step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..f797e1a8cf35999531dd5f1ccbbe09a9d0cf30a9 --- /dev/null +++ b/src/custom_timm/scheduler/step_lr.py @@ -0,0 +1,63 @@ +""" Step Scheduler + +Basic step LR schedule with warmup, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/src/custom_timm/scheduler/tanh_lr.py b/src/custom_timm/scheduler/tanh_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d3c9cdb11ad31766062f1a8d3e69d3f845edc1 --- /dev/null +++ b/src/custom_timm/scheduler/tanh_lr.py @@ -0,0 +1,117 @@ +""" TanH Scheduler + +TanH schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class TanhLRScheduler(Scheduler): + """ + Hyberbolic-Tangent decay with restarts. + This is described in the paper https://arxiv.org/abs/1806.01593 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -7., + ub: float = 3., + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + assert lb < ub + assert cycle_limit >= 0 + assert warmup_t >= 0 + assert warmup_lr_init >= 0 + self.lb = lb + self.ub = ub + self.t_initial = t_initial + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + if i < self.cycle_limit: + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + + tr = t_curr / t_i + lrs = [ + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/src/custom_timm/utils/__init__.py b/src/custom_timm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b139852d79644f97de7cf373a1a4c3dbd17f050 --- /dev/null +++ b/src/custom_timm/utils/__init__.py @@ -0,0 +1,14 @@ +from .agc import adaptive_clip_grad +from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad +from .cuda import ApexScaler, NativeScaler +from .decay_batch import decay_batch_step, check_batch_size_retry +from .distributed import distribute_bn, reduce_tensor +from .jit import set_jit_legacy, set_jit_fuser +from .log import setup_default_logging, FormatterNoInfo +from .metrics import AverageMeter, accuracy +from .misc import natural_key, add_bool_arg +from .model import unwrap_model, get_state_dict, freeze, unfreeze +from .model_ema import ModelEma, ModelEmaV2 +from .random import random_seed +from .summary import update_summary, get_outdir diff --git a/src/custom_timm/utils/agc.py b/src/custom_timm/utils/agc.py new file mode 100644 index 0000000000000000000000000000000000000000..f51401726ff6810d97d0fa567f4e31b474325a59 --- /dev/null +++ b/src/custom_timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/src/custom_timm/utils/checkpoint_saver.py b/src/custom_timm/utils/checkpoint_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..6aad74ee52655f68220f799efaffcbccdd0748ad --- /dev/null +++ b/src/custom_timm/utils/checkpoint_saver.py @@ -0,0 +1,150 @@ +""" Checkpoint Saver + +Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import glob +import operator +import os +import logging + +import torch + +from .model import unwrap_model, get_state_dict + + +_logger = logging.getLogger(__name__) + + +class CheckpointSaver: + def __init__( + self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, + checkpoint_prefix='checkpoint', + recovery_prefix='recovery', + checkpoint_dir='', + recovery_dir='', + decreasing=False, + max_history=10, + unwrap_fn=unwrap_model): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler + + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None + self.best_metric = None + self.curr_recovery_file = '' + self.last_recovery_file = '' + + # config + self.checkpoint_dir = checkpoint_dir + self.recovery_dir = recovery_dir + self.save_prefix = checkpoint_prefix + self.recovery_prefix = recovery_prefix + self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.max_history = max_history + self.unwrap_fn = unwrap_fn + assert self.max_history >= 1 + + def save_checkpoint(self, epoch, metric=None): + assert epoch >= 0 + tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) + last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) + self._save(tmp_save_path, epoch, metric) + if os.path.exists(last_save_path): + os.unlink(last_save_path) # required for Windows support. + os.rename(tmp_save_path, last_save_path) + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if (len(self.checkpoint_files) < self.max_history + or metric is None or self.cmp(metric, worst_file[1])): + if len(self.checkpoint_files) >= self.max_history: + self._cleanup_checkpoints(1) + filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension + save_path = os.path.join(self.checkpoint_dir, filename) + os.link(last_save_path, save_path) + self.checkpoint_files.append((save_path, metric)) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + checkpoints_str = "Current checkpoints:\n" + for c in self.checkpoint_files: + checkpoints_str += ' {}\n'.format(c) + _logger.info(checkpoints_str) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric + best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) + if os.path.exists(best_save_path): + os.unlink(best_save_path) + os.link(last_save_path, best_save_path) + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) + + def _save(self, save_path, epoch, metric=None): + save_state = { + 'epoch': epoch, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), + 'version': 2, # version < 2 increments epoch before save + } + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) + if metric is not None: + save_state['metric'] = metric + torch.save(save_state, save_path) + + def _cleanup_checkpoints(self, trim=0): + trim = min(len(self.checkpoint_files), trim) + delete_index = self.max_history - trim + if delete_index < 0 or len(self.checkpoint_files) <= delete_index: + return + to_delete = self.checkpoint_files[delete_index:] + for d in to_delete: + try: + _logger.debug("Cleaning checkpoint: {}".format(d)) + os.remove(d[0]) + except Exception as e: + _logger.error("Exception '{}' while deleting checkpoint".format(e)) + self.checkpoint_files = self.checkpoint_files[:delete_index] + + def save_recovery(self, epoch, batch_idx=0): + assert epoch >= 0 + filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension + save_path = os.path.join(self.recovery_dir, filename) + self._save(save_path, epoch) + if os.path.exists(self.last_recovery_file): + try: + _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) + os.remove(self.last_recovery_file) + except Exception as e: + _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) + self.last_recovery_file = self.curr_recovery_file + self.curr_recovery_file = save_path + + def find_recovery(self): + recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) + files = glob.glob(recovery_path + '*' + self.extension) + files = sorted(files) + return files[0] if len(files) else '' diff --git a/src/custom_timm/utils/clip_grad.py b/src/custom_timm/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..73671d3a5d2ad856630ce2b2d7b0d6e6e627c59a --- /dev/null +++ b/src/custom_timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from custom_timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/src/custom_timm/utils/cuda.py b/src/custom_timm/utils/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7bddf30463a7be7186c7def47c4e4dfb9993aa --- /dev/null +++ b/src/custom_timm/utils/cuda.py @@ -0,0 +1,55 @@ +""" CUDA / AMP utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +try: + from apex import amp + has_apex = True +except ImportError: + amp = None + has_apex = False + +from .clip_grad import dispatch_clip_grad + + +class ApexScaler: + state_dict_key = "amp" + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward(create_graph=create_graph) + if clip_grad is not None: + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScaler: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + self._scaler.scale(loss).backward(create_graph=create_graph) + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) diff --git a/src/custom_timm/utils/decay_batch.py b/src/custom_timm/utils/decay_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..852fa4b8dc3d46932b67ed3e42170a5de92415d9 --- /dev/null +++ b/src/custom_timm/utils/decay_batch.py @@ -0,0 +1,43 @@ +""" Batch size decay and retry helpers. + +Copyright 2022 Ross Wightman +""" +import math + + +def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False): + """ power of two batch-size decay with intra steps + + Decay by stepping between powers of 2: + * determine power-of-2 floor of current batch size (base batch size) + * divide above value by num_intra_steps to determine step size + * floor batch_size to nearest multiple of step_size (from base batch size) + Examples: + num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1 + num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2 + num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 + num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1 + """ + if batch_size <= 1: + # return 0 for stopping value so easy to use in loop + return 0 + base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2))) + step_size = max(base_batch_size // num_intra_steps, 1) + batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size + if no_odd and batch_size % 2: + batch_size -= 1 + return batch_size + + +def check_batch_size_retry(error_str): + """ check failure error string for conditions where batch decay retry should not be attempted + """ + error_str = error_str.lower() + if 'required rank' in error_str: + # Errors involving phrase 'required rank' typically happen when a conv is used that's + # not compatible with channels_last memory format. + return False + if 'illegal' in error_str: + # 'Illegal memory access' errors in CUDA typically leave process in unusable state + return False + return True diff --git a/src/custom_timm/utils/distributed.py b/src/custom_timm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dba8c1de5a6ff53638207521377fdfbc4f239 --- /dev/null +++ b/src/custom_timm/utils/distributed.py @@ -0,0 +1,28 @@ +""" Distributed training/validation utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import distributed as dist + +from .model import unwrap_model + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +def distribute_bn(model, world_size, reduce=False): + # ensure every node has the same running bn stats + for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): + if ('running_mean' in bn_name) or ('running_var' in bn_name): + if reduce: + # average bn stats across whole group + torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) + bn_buf /= float(world_size) + else: + # broadcast bn stats from rank 0 to whole group + torch.distributed.broadcast(bn_buf, 0) diff --git a/src/custom_timm/utils/jit.py b/src/custom_timm/utils/jit.py new file mode 100644 index 0000000000000000000000000000000000000000..d527411fd3e1985639bb0b161bd484142a3619dd --- /dev/null +++ b/src/custom_timm/utils/jit.py @@ -0,0 +1,58 @@ +""" JIT scripting/tracing utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os + +import torch + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) + + +def set_jit_fuser(fuser): + if fuser == "te": + # default fuser should be == 'te' + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(True) + try: + torch._C._jit_set_nvfuser_enabled(False) + except Exception: + pass + elif fuser == "old" or fuser == "legacy": + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(False) + try: + torch._C._jit_set_nvfuser_enabled(False) + except Exception: + pass + elif fuser == "nvfuser" or fuser == "nvf": + os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' + #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' + #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_can_fuse_on_cpu() + torch._C._jit_can_fuse_on_gpu() + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_nvfuser_guard_mode(True) + torch._C._jit_set_nvfuser_enabled(True) + else: + assert False, f"Invalid jit fuser ({fuser})" diff --git a/src/custom_timm/utils/log.py b/src/custom_timm/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..c99469e0884f3e45905ef7c7f0d1e491092697ad --- /dev/null +++ b/src/custom_timm/utils/log.py @@ -0,0 +1,28 @@ +""" Logging helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import logging.handlers + + +class FormatterNoInfo(logging.Formatter): + def __init__(self, fmt='%(levelname)s: %(message)s'): + logging.Formatter.__init__(self, fmt) + + def format(self, record): + if record.levelno == logging.INFO: + return str(record.getMessage()) + return logging.Formatter.format(self, record) + + +def setup_default_logging(default_level=logging.INFO, log_path=''): + console_handler = logging.StreamHandler() + console_handler.setFormatter(FormatterNoInfo()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) + if log_path: + file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") + file_handler.setFormatter(file_formatter) + logging.root.addHandler(file_handler) diff --git a/src/custom_timm/utils/metrics.py b/src/custom_timm/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdbe13ef15c541679906239374ff8a7eedf5181 --- /dev/null +++ b/src/custom_timm/utils/metrics.py @@ -0,0 +1,32 @@ +""" Eval metrics and related + +Hacked together by / Copyright 2020 Ross Wightman +""" + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = min(max(topk), output.size()[1]) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] diff --git a/src/custom_timm/utils/misc.py b/src/custom_timm/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..39c0097c60ed602547f832f1f8dafbe37f156064 --- /dev/null +++ b/src/custom_timm/utils/misc.py @@ -0,0 +1,18 @@ +""" Misc utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import re + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def add_bool_arg(parser, name, default=False, help=''): + dest_name = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) + group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) + parser.set_defaults(**{dest_name: default}) diff --git a/src/custom_timm/utils/model.py b/src/custom_timm/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b95c45392bfb551f52bc8b8dca1aaf8c8b1940b1 --- /dev/null +++ b/src/custom_timm/utils/model.py @@ -0,0 +1,273 @@ +""" Model / state_dict utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import fnmatch + +import torch +from torchvision.ops.misc import FrozenBatchNorm2d + +from .model_ema import ModelEma + + +def unwrap_model(model): + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model + + +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() + + +def avg_sq_ch_mean(model, input, output): + """ calculate average channel square mean of output activations + """ + return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item() + + +def avg_ch_var(model, input, output): + """ calculate average channel variance of output activations + """ + return torch.mean(output.var(axis=[0, 2, 3])).item() + + +def avg_ch_var_residual(model, input, output): + """ calculate average channel variance of output activations + """ + return torch.mean(output.var(axis=[0, 2, 3])).item() + + +class ActivationStatsHook: + """Iterates through each of `model`'s modules and matches modules using unix pattern + matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is + a match. + + Arguments: + model (nn.Module): model from which we will extract the activation stats + hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string + matching with the name of model's modules. + hook_fns (List[Callable]): List of hook functions to be registered at every + module in `layer_names`. + + Inspiration from https://docs.fast.ai/callback.hook.html. + + Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example + on how to plot Signal Propogation Plots using `ActivationStatsHook`. + """ + + def __init__(self, model, hook_fn_locs, hook_fns): + self.model = model + self.hook_fn_locs = hook_fn_locs + self.hook_fns = hook_fns + if len(hook_fn_locs) != len(hook_fns): + raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \ + their lengths are different.") + self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) + for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): + self.register_hook(hook_fn_loc, hook_fn) + + def _create_hook(self, hook_fn): + def append_activation_stats(module, input, output): + out = hook_fn(module, input, output) + self.stats[hook_fn.__name__].append(out) + + return append_activation_stats + + def register_hook(self, hook_fn_loc, hook_fn): + for name, module in self.model.named_modules(): + if not fnmatch.fnmatch(name, hook_fn_loc): + continue + module.register_forward_hook(self._create_hook(hook_fn)) + + +def extract_spp_stats( + model, + hook_fn_locs, + hook_fns, + input_shape=[8, 3, 224, 224]): + """Extract average square channel mean and variance of activations during + forward pass to plot Signal Propogation Plots (SPP). + + Paper: https://arxiv.org/abs/2101.08692 + + Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 + """ + x = torch.normal(0., 1., input_shape) + hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) + _ = model(x) + return hook.stats + + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): + """ + Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is + done in place. + Args: + root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be (un)frozen. Defaults to [] + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers. + Defaults to `True`. + mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`. + """ + assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' + + if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + # Raise assertion here because we can't convert it in place + raise AssertionError( + "You have provided a batch norm layer as the `root module`. Please use " + "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.") + + if isinstance(submodules, str): + submodules = [submodules] + + named_modules = submodules + submodules = [root_module.get_submodule(m) for m in submodules] + + if not len(submodules): + named_modules, submodules = list(zip(*root_module.named_children())) + + for n, m in zip(named_modules, submodules): + # (Un)freeze parameters + for p in m.parameters(): + p.requires_grad = False if mode == 'freeze' else True + if include_bn_running_stats: + # Helper to add submodule specified as a named_module + def _add_submodule(module, name, submodule): + split = name.rsplit('.', 1) + if len(split) > 1: + module.get_submodule(split[0]).add_module(split[1], submodule) + else: + module.add_module(name, submodule) + + # Freeze batch norm + if mode == 'freeze': + res = freeze_batch_norm_2d(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't + # convert it in place, but will return the converted result. In this case `res` holds the converted + # result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + _add_submodule(root_module, n, res) + # Unfreeze batch norm + else: + res = unfreeze_batch_norm_2d(m) + # Ditto. See note above in mode == 'freeze' branch + if isinstance(m, FrozenBatchNorm2d): + _add_submodule(root_module, n, res) + + +def freeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be frozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and + `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning, + it's good practice to freeze batch norm stats. And note that these are different to the affine parameters + which are just normal PyTorch parameters. Defaults to `True`. + + Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`. + + Examples:: + + >>> model = timm.create_model('resnet18') + >>> # Freeze up to and including layer2 + >>> submodules = [n for n, _ in model.named_children()] + >>> print(submodules) + ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc'] + >>> freeze(model, submodules[:submodules.index('layer2') + 1]) + >>> # Check for yourself that it works as expected + >>> print(model.layer2[0].conv1.weight.requires_grad) + False + >>> print(model.layer3[0].conv1.weight.requires_grad) + True + >>> # Unfreeze + >>> unfreeze(model) + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze") + + +def unfreeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided + as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty + list means that the whole root module will be unfrozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers. + These will be converted to `BatchNorm2d` in place. Defaults to `True`. + + See example in docstring for `freeze`. + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") diff --git a/src/custom_timm/utils/model_ema.py b/src/custom_timm/utils/model_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..073d5c5ea1a4afc5aa3817b6354b2566f8cc2cf5 --- /dev/null +++ b/src/custom_timm/utils/model_ema.py @@ -0,0 +1,126 @@ +""" Exponential Moving Average (EMA) of model updates + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn + +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) diff --git a/src/custom_timm/utils/random.py b/src/custom_timm/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..a9679983e96a9a6634c0b77aaf7b996e70eff50b --- /dev/null +++ b/src/custom_timm/utils/random.py @@ -0,0 +1,9 @@ +import random +import numpy as np +import torch + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) diff --git a/src/custom_timm/utils/summary.py b/src/custom_timm/utils/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5af9a08598556c3fed136f258f88bd578c1e1c --- /dev/null +++ b/src/custom_timm/utils/summary.py @@ -0,0 +1,39 @@ +""" Summary utilities + +Hacked together by / Copyright 2020 Ross Wightman +""" +import csv +import os +from collections import OrderedDict +try: + import wandb +except ImportError: + pass + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): + rowd = OrderedDict(epoch=epoch) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if log_wandb: + wandb.log(rowd) + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd)