diff --git a/proard/__init__.py b/proard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/proard/classification/__init__.py b/proard/classification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/proard/classification/data_providers/__init__.py b/proard/classification/data_providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e0ff93f45f1edcc8ed46c5c3814526264600b1 --- /dev/null +++ b/proard/classification/data_providers/__init__.py @@ -0,0 +1,3 @@ +from .cifar10 import * +from .cifar100 import * +from .imagenet import * \ No newline at end of file diff --git a/proard/classification/data_providers/base_provider.py b/proard/classification/data_providers/base_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd6ff040bbc02655c775b008b9ecad0487123db --- /dev/null +++ b/proard/classification/data_providers/base_provider.py @@ -0,0 +1,58 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import numpy as np +import torch + +__all__ = ["DataProvider"] + + +class DataProvider: + SUB_SEED = 937162211 # random seed for sampling subset + VALID_SEED = 2147483647 # random seed for the validation set + + @staticmethod + def name(): + """Return name of the dataset""" + raise NotImplementedError + + @property + def data_shape(self): + """Return shape as python list of one data entry""" + raise NotImplementedError + + @property + def n_classes(self): + """Return `int` of num classes""" + raise NotImplementedError + + @property + def save_path(self): + """local path to save the data""" + raise NotImplementedError + + @property + def data_url(self): + """link to download the data""" + raise NotImplementedError + + @staticmethod + def random_sample_valid_set(train_size, valid_size): + assert train_size > valid_size + + g = torch.Generator() + g.manual_seed( + DataProvider.VALID_SEED + ) # set random seed before sampling validation set + rand_indexes = torch.randperm(train_size, generator=g).tolist() + + valid_indexes = rand_indexes[:valid_size] + train_indexes = rand_indexes[valid_size:] + return train_indexes, valid_indexes + + @staticmethod + def labels_to_one_hot(n_classes, labels): + new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32) + new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape) + return new_labels \ No newline at end of file diff --git a/proard/classification/data_providers/cifar10.py b/proard/classification/data_providers/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..ac328f343e76a3249e40056f5c64f6bf949db3e7 --- /dev/null +++ b/proard/classification/data_providers/cifar10.py @@ -0,0 +1,264 @@ +import warnings +import os +import math +import numpy as np +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from .base_provider import DataProvider +from proard.utils.my_dataloader import MyRandomResizedCrop, 
MyDistributedSampler + +__all__ = ["Cifar10DataProvider"] + +class Cifar10DataProvider(DataProvider): + DEFAULT_PATH = "./dataset/cifar10" + def __init__( + self, + save_path=None, + train_batch_size=256, + test_batch_size=512, + valid_size=None, + resize_scale=0.08, + distort_color=None, + n_worker=32, + image_size=32, + num_replicas=None, + rank=None, + ): + + warnings.filterwarnings("ignore") + self._save_path = save_path + + self.image_size = image_size # int or list of int + + + self._valid_transform_dict = {} + if not isinstance(self.image_size, int): + from proard.utils.my_dataloader.my_data_loader import MyDataLoader + + assert isinstance(self.image_size, list) + self.image_size.sort() # e.g., 160 -> 224 + MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() + MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) + + for img_size in self.image_size: + self._valid_transform_dict[img_size] = self.build_valid_transform( + img_size + ) + self.active_img_size = max(self.image_size) # active resolution for test + valid_transforms = self._valid_transform_dict[self.active_img_size] + train_loader_class = MyDataLoader # randomly sample image size for each batch of training image + else: + self.active_img_size = self.image_size + valid_transforms = self.build_valid_transform() + train_loader_class = torch.utils.data.DataLoader + + train_dataset = self.train_dataset(self.build_train_transform()) + + if valid_size is not None: + if not isinstance(valid_size, int): + assert isinstance(valid_size, float) and 0 < valid_size < 1 + valid_size = int(len(train_dataset) * valid_size) + + valid_dataset = self.train_dataset(valid_transforms) + train_indexes, valid_indexes = self.random_sample_valid_set( + len(train_dataset), valid_size + ) + + if num_replicas is not None: + train_sampler = MyDistributedSampler( + train_dataset, num_replicas, rank, True, np.array(train_indexes) + ) + valid_sampler = MyDistributedSampler( + valid_dataset, num_replicas, rank, True, 
np.array(valid_indexes) + ) + else: + train_sampler = torch.utils.data.sampler.SubsetRandomSampler( + train_indexes + ) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler( + valid_indexes + ) + + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=n_worker, + pin_memory=False, + ) + self.valid = torch.utils.data.DataLoader( + valid_dataset, + batch_size=test_batch_size, + sampler=valid_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + if num_replicas is not None: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas, rank + ) + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=n_worker, + pin_memory=True, + ) + else: + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + self.valid = None + + test_dataset = self.test_dataset(valid_transforms) + if num_replicas is not None: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, num_replicas, rank + ) + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + sampler=test_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + + if self.valid is None: + self.valid = self.test + + @staticmethod + def name(): + return "cifar10" + + @property + def data_shape(self): + return 3, self.active_img_size, self.active_img_size # C, H, W + + @property + def n_classes(self): + return 10 + + @property + def save_path(self): + if self._save_path is None: + self._save_path = self.DEFAULT_PATH + if not os.path.exists(self._save_path): + self._save_path = os.path.expanduser("~/dataset/cifar10") + return 
self._save_path + + @property + def data_url(self): + raise ValueError("unable to download %s" % self.name()) + + def train_dataset(self, _transforms): + return datasets.CIFAR10(self.train_path, train=True, transform=_transforms,download=True) + + def test_dataset(self, _transforms): + return datasets.CIFAR10(self.valid_path, train=False, transform=_transforms,download=True) + @property + def train_path(self): + return os.path.join(self.save_path, "train") + + @property + def valid_path(self): + return os.path.join(self.save_path, "val") + + @property + def normalize(self): + return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + + def build_train_transform(self, image_size=None, print_log=True): + if image_size is None: + image_size = self.image_size + + # random_resize_crop -> random_horizontal_flip + train_transforms = [ + transforms.RandomCrop(32,padding=4), + transforms.RandomHorizontalFlip(), + # AutoAugment(), + ] + + train_transforms += [ + transforms.ToTensor(), + # self.normalize, + ] + + train_transforms = transforms.Compose(train_transforms) + return train_transforms + + def build_valid_transform(self, image_size=None): + if image_size is None: + image_size = self.active_img_size + return transforms.Compose([ + transforms.ToTensor(), + # self.normalize, + ]) + + def assign_active_img_size(self, new_img_size): + self.active_img_size = new_img_size + if self.active_img_size not in self._valid_transform_dict: + self._valid_transform_dict[ + self.active_img_size + ] = self.build_valid_transform() + # change the transform of the valid and test set + self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] + self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] + + def build_sub_train_loader( + self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None + ): + # used for resetting BN running statistics + if self.__dict__.get("sub_train_%d" % self.active_img_size, 
None) is None: + if num_worker is None: + num_worker = self.train.num_workers + + n_samples = len(self.train.dataset) + g = torch.Generator() + g.manual_seed(DataProvider.SUB_SEED) + rand_indexes = torch.randperm(n_samples, generator=g).tolist() + + new_train_dataset = self.train_dataset( + self.build_train_transform( + image_size=self.active_img_size, print_log=False + ) + ) + chosen_indexes = rand_indexes[:n_images] + if num_replicas is not None: + sub_sampler = MyDistributedSampler( + new_train_dataset, + num_replicas, + rank, + True, + np.array(chosen_indexes), + ) + else: + sub_sampler = torch.utils.data.sampler.SubsetRandomSampler( + chosen_indexes + ) + sub_data_loader = torch.utils.data.DataLoader( + new_train_dataset, + batch_size=batch_size, + sampler=sub_sampler, + num_workers=num_worker, + pin_memory=False, + ) + self.__dict__["sub_train_%d" % self.active_img_size] = [] + for images, labels in sub_data_loader: + self.__dict__["sub_train_%d" % self.active_img_size].append( + (images, labels) + ) + return self.__dict__["sub_train_%d" % self.active_img_size] diff --git a/proard/classification/data_providers/cifar100.py b/proard/classification/data_providers/cifar100.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7423e63eb64866b149016b88c7feef18a5262b --- /dev/null +++ b/proard/classification/data_providers/cifar100.py @@ -0,0 +1,264 @@ +import warnings +import os +import math +import numpy as np +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from .base_provider import DataProvider +from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler + +__all__ = ["Cifar100DataProvider"] + +class Cifar100DataProvider(DataProvider): + DEFAULT_PATH = "./dataset/cifar100" + def __init__( + self, + save_path=None, + train_batch_size=256, + test_batch_size=512, + resize_scale=0.08, + distort_color=None, + valid_size=None, + n_worker=32, + image_size=32, + 
num_replicas=None, + rank=None, + ): + + warnings.filterwarnings("ignore") + self._save_path = save_path + + self.image_size = image_size # int or list of int + + + self._valid_transform_dict = {} + if not isinstance(self.image_size, int): + from proard.utils.my_dataloader.my_data_loader import MyDataLoader + + assert isinstance(self.image_size, list) + self.image_size.sort() # e.g., 160 -> 224 + MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() + MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) + + for img_size in self.image_size: + self._valid_transform_dict[img_size] = self.build_valid_transform( + img_size + ) + self.active_img_size = max(self.image_size) # active resolution for test + valid_transforms = self._valid_transform_dict[self.active_img_size] + train_loader_class = MyDataLoader # randomly sample image size for each batch of training image + else: + self.active_img_size = self.image_size + valid_transforms = self.build_valid_transform() + train_loader_class = torch.utils.data.DataLoader + + train_dataset = self.train_dataset(self.build_train_transform()) + + if valid_size is not None: + if not isinstance(valid_size, int): + assert isinstance(valid_size, float) and 0 < valid_size < 1 + valid_size = int(len(train_dataset) * valid_size) + + valid_dataset = self.train_dataset(valid_transforms) + train_indexes, valid_indexes = self.random_sample_valid_set( + len(train_dataset), valid_size + ) + + if num_replicas is not None: + train_sampler = MyDistributedSampler( + train_dataset, num_replicas, rank, True, np.array(train_indexes) + ) + valid_sampler = MyDistributedSampler( + valid_dataset, num_replicas, rank, True, np.array(valid_indexes) + ) + else: + train_sampler = torch.utils.data.sampler.SubsetRandomSampler( + train_indexes + ) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler( + valid_indexes + ) + + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + 
num_workers=n_worker, + pin_memory=False, + ) + self.valid = torch.utils.data.DataLoader( + valid_dataset, + batch_size=test_batch_size, + sampler=valid_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + if num_replicas is not None: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas, rank + ) + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=n_worker, + pin_memory=True, + ) + else: + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + self.valid = None + + test_dataset = self.test_dataset(valid_transforms) + if num_replicas is not None: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, num_replicas, rank + ) + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + sampler=test_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + + if self.valid is None: + self.valid = self.test + + @staticmethod + def name(): + return "cifar100" + + @property + def data_shape(self): + return 3, self.active_img_size, self.active_img_size # C, H, W + + @property + def n_classes(self): + return 100 + + @property + def save_path(self): + if self._save_path is None: + self._save_path = self.DEFAULT_PATH + if not os.path.exists(self._save_path): + self._save_path = os.path.expanduser("~/dataset/cifar100") + return self._save_path + + @property + def data_url(self): + raise ValueError("unable to download %s" % self.name()) + + def train_dataset(self, _transforms): + return datasets.CIFAR100(self.train_path, train=True, transform=_transforms,download=True) + + def test_dataset(self, _transforms): + return 
datasets.CIFAR100(self.valid_path, train=False, transform=_transforms,download=True) + @property + def train_path(self): + return os.path.join(self.save_path, "train") + + @property + def valid_path(self): + return os.path.join(self.save_path, "val") + + @property + def normalize(self): + return transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + + def build_train_transform(self, image_size=None, print_log=True): + if image_size is None: + image_size = self.image_size + + # random_resize_crop -> random_horizontal_flip + train_transforms = [ + transforms.RandomCrop(32,padding=4), + transforms.RandomHorizontalFlip(), + # AutoAugment(), + ] + + train_transforms += [ + transforms.ToTensor(), + # self.normalize, + ] + + train_transforms = transforms.Compose(train_transforms) + return train_transforms + + def build_valid_transform(self, image_size=None): + if image_size is None: + image_size = self.active_img_size + return transforms.Compose([ + transforms.ToTensor(), + # self.normalize, + ]) + + def assign_active_img_size(self, new_img_size): + self.active_img_size = new_img_size + if self.active_img_size not in self._valid_transform_dict: + self._valid_transform_dict[ + self.active_img_size + ] = self.build_valid_transform() + # change the transform of the valid and test set + self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] + self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] + + def build_sub_train_loader( + self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None + ): + # used for resetting BN running statistics + if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None: + if num_worker is None: + num_worker = self.train.num_workers + + n_samples = len(self.train.dataset) + g = torch.Generator() + g.manual_seed(DataProvider.SUB_SEED) + rand_indexes = torch.randperm(n_samples, generator=g).tolist() + + new_train_dataset = self.train_dataset( + 
self.build_train_transform( + image_size=self.active_img_size, print_log=False + ) + ) + chosen_indexes = rand_indexes[:n_images] + if num_replicas is not None: + sub_sampler = MyDistributedSampler( + new_train_dataset, + num_replicas, + rank, + True, + np.array(chosen_indexes), + ) + else: + sub_sampler = torch.utils.data.sampler.SubsetRandomSampler( + chosen_indexes + ) + sub_data_loader = torch.utils.data.DataLoader( + new_train_dataset, + batch_size=batch_size, + sampler=sub_sampler, + num_workers=num_worker, + pin_memory=False, + ) + self.__dict__["sub_train_%d" % self.active_img_size] = [] + for images, labels in sub_data_loader: + self.__dict__["sub_train_%d" % self.active_img_size].append( + (images, labels) + ) + return self.__dict__["sub_train_%d" % self.active_img_size] diff --git a/proard/classification/data_providers/imagenet.py b/proard/classification/data_providers/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1184b33b63887e6d10c29fb7064938f28f2d25 --- /dev/null +++ b/proard/classification/data_providers/imagenet.py @@ -0,0 +1,310 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import warnings +import os +import math +import numpy as np +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from .base_provider import DataProvider +from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler + +__all__ = ["ImagenetDataProvider"] + + +class ImagenetDataProvider(DataProvider): + DEFAULT_PATH = "./dataset/imagenet" + + def __init__( + self, + save_path=None, + train_batch_size=256, + test_batch_size=512, + valid_size=None, + n_worker=32, + resize_scale=0.08, + distort_color=None, + image_size=224, + num_replicas=None, + rank=None, + ): + + warnings.filterwarnings("ignore") + self._save_path = save_path + + self.image_size = image_size # int or list of int + self.distort_color = "None" if distort_color is None else distort_color + self.resize_scale = resize_scale + + self._valid_transform_dict = {} + if not isinstance(self.image_size, int): + from proard.utils.my_dataloader.my_data_loader import MyDataLoader + + assert isinstance(self.image_size, list) + self.image_size.sort() # e.g., 160 -> 224 + MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() + MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) + + for img_size in self.image_size: + self._valid_transform_dict[img_size] = self.build_valid_transform( + img_size + ) + self.active_img_size = max(self.image_size) # active resolution for test + valid_transforms = self._valid_transform_dict[self.active_img_size] + train_loader_class = MyDataLoader # randomly sample image size for each batch of training image + else: + self.active_img_size = self.image_size + valid_transforms = self.build_valid_transform() + train_loader_class = torch.utils.data.DataLoader + + train_dataset = self.train_dataset(self.build_train_transform()) + + if valid_size is not None: + if not isinstance(valid_size, int): + assert isinstance(valid_size, float) and 0 < valid_size < 1 + valid_size = int(len(train_dataset) * 
valid_size) + + valid_dataset = self.train_dataset(valid_transforms) + train_indexes, valid_indexes = self.random_sample_valid_set( + len(train_dataset), valid_size + ) + + if num_replicas is not None: + train_sampler = MyDistributedSampler( + train_dataset, num_replicas, rank, True, np.array(train_indexes) + ) + valid_sampler = MyDistributedSampler( + valid_dataset, num_replicas, rank, True, np.array(valid_indexes) + ) + else: + train_sampler = torch.utils.data.sampler.SubsetRandomSampler( + train_indexes + ) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler( + valid_indexes + ) + + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=n_worker, + pin_memory=False, + ) + self.valid = torch.utils.data.DataLoader( + valid_dataset, + batch_size=test_batch_size, + sampler=valid_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + if num_replicas is not None: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas, rank + ) + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + sampler=train_sampler, + num_workers=n_worker, + pin_memory=True, + ) + else: + self.train = train_loader_class( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + self.valid = None + + test_dataset = self.test_dataset(valid_transforms) + if num_replicas is not None: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, num_replicas, rank + ) + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + sampler=test_sampler, + num_workers=n_worker, + pin_memory=False, + ) + else: + self.test = torch.utils.data.DataLoader( + test_dataset, + batch_size=test_batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=False, + ) + + if self.valid is None: + self.valid = self.test + + @staticmethod + def name(): + 
return "imagenet" + + @property + def data_shape(self): + return 3, self.active_img_size, self.active_img_size # C, H, W + + @property + def n_classes(self): + return 1000 + + @property + def save_path(self): + if self._save_path is None: + self._save_path = self.DEFAULT_PATH + if not os.path.exists(self._save_path): + self._save_path = os.path.expanduser("~/dataset/imagenet") + return self._save_path + + @property + def data_url(self): + raise ValueError("unable to download %s" % self.name()) + + def train_dataset(self, _transforms): + return datasets.ImageFolder(self.train_path, _transforms) + + def test_dataset(self, _transforms): + return datasets.ImageFolder(self.valid_path, _transforms) + + @property + def train_path(self): + return os.path.join(self.save_path, "train") + + @property + def valid_path(self): + return os.path.join(self.save_path, "val") + + @property + def normalize(self): + return transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + def build_train_transform(self, image_size=None, print_log=True): + if image_size is None: + image_size = self.image_size + if print_log: + print( + "Color jitter: %s, resize_scale: %s, img_size: %s" + % (self.distort_color, self.resize_scale, image_size) + ) + + if isinstance(image_size, list): + resize_transform_class = MyRandomResizedCrop + print( + "Use MyRandomResizedCrop: %s, \t %s" + % MyRandomResizedCrop.get_candidate_image_size(), + "sync=%s, continuous=%s" + % ( + MyRandomResizedCrop.SYNC_DISTRIBUTED, + MyRandomResizedCrop.CONTINUOUS, + ), + ) + else: + resize_transform_class = transforms.RandomResizedCrop + + # random_resize_crop -> random_horizontal_flip + train_transforms = [ + resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), + transforms.RandomHorizontalFlip(), + ] + + # color augmentation (optional) + color_transform = None + if self.distort_color == "torch": + color_transform = transforms.ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, 
hue=0.1 + ) + elif self.distort_color == "tf": + color_transform = transforms.ColorJitter( + brightness=32.0 / 255.0, saturation=0.5 + ) + if color_transform is not None: + train_transforms.append(color_transform) + + train_transforms += [ + transforms.ToTensor(), + self.normalize, + ] + + train_transforms = transforms.Compose(train_transforms) + return train_transforms + + def build_valid_transform(self, image_size=None): + if image_size is None: + image_size = self.active_img_size + return transforms.Compose( + [ + transforms.Resize(int(math.ceil(image_size / 0.875))), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + self.normalize, + ] + ) + + def assign_active_img_size(self, new_img_size): + self.active_img_size = new_img_size + if self.active_img_size not in self._valid_transform_dict: + self._valid_transform_dict[ + self.active_img_size + ] = self.build_valid_transform() + # change the transform of the valid and test set + self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] + self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] + + def build_sub_train_loader( + self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None + ): + # used for resetting BN running statistics + if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None: + if num_worker is None: + num_worker = self.train.num_workers + + n_samples = len(self.train.dataset) + g = torch.Generator() + g.manual_seed(DataProvider.SUB_SEED) + rand_indexes = torch.randperm(n_samples, generator=g).tolist() + + new_train_dataset = self.train_dataset( + self.build_train_transform( + image_size=self.active_img_size, print_log=False + ) + ) + chosen_indexes = rand_indexes[:n_images] + if num_replicas is not None: + sub_sampler = MyDistributedSampler( + new_train_dataset, + num_replicas, + rank, + True, + np.array(chosen_indexes), + ) + else: + sub_sampler = torch.utils.data.sampler.SubsetRandomSampler( + 
chosen_indexes + ) + sub_data_loader = torch.utils.data.DataLoader( + new_train_dataset, + batch_size=batch_size, + sampler=sub_sampler, + num_workers=num_worker, + pin_memory=False, + ) + self.__dict__["sub_train_%d" % self.active_img_size] = [] + for images, labels in sub_data_loader: + self.__dict__["sub_train_%d" % self.active_img_size].append( + (images, labels) + ) + return self.__dict__["sub_train_%d" % self.active_img_size] diff --git a/proard/classification/elastic_nn/__init__.py b/proard/classification/elastic_nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/proard/classification/elastic_nn/modules/__init__.py b/proard/classification/elastic_nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4173c1e75363ff6bf3755ddc291b626be292164c --- /dev/null +++ b/proard/classification/elastic_nn/modules/__init__.py @@ -0,0 +1,6 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +from .dynamic_layers import * +from .dynamic_op import * diff --git a/proard/classification/elastic_nn/modules/dynamic_layers.py b/proard/classification/elastic_nn/modules/dynamic_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f96ed498b6966b2cee8a17ef32ec7c22b92e71 --- /dev/null +++ b/proard/classification/elastic_nn/modules/dynamic_layers.py @@ -0,0 +1,841 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
# dynamic_layers.py — elastic (dynamic) layer wrappers for once-for-all style
# super-networks. Each Dynamic* module allocates weights for the LARGEST
# configuration; smaller sub-configurations run by slicing those weights, and
# get_active_subnet() exports a static layer with the sliced weights copied in.

import copy
import torch
import torch.nn as nn
from collections import OrderedDict

from proard.utils.layers import (
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    set_layer_from_config,
)
from proard.utils.layers import ResNetBottleneckBlock, LinearLayer
from proard.utils import (
    MyModule,
    val2list,
    get_net_device,
    build_activation,
    make_divisible,
    SEModule,
    MyNetwork,
)
from .dynamic_op import (
    DynamicSeparableConv2d,
    DynamicConv2d,
    DynamicBatchNorm2d,
    DynamicSE,
    DynamicGroupNorm,
)
from .dynamic_op import DynamicLinear

__all__ = [
    "adjust_bn_according_to_idx",
    "copy_bn",
    "DynamicMBConvLayer",
    "DynamicConvLayer",
    "DynamicLinearLayer",
    "DynamicResNetBottleneckBlock",
]


def adjust_bn_according_to_idx(bn, idx):
    """Permute the channel dimension of a norm layer in place.

    ``idx`` is a LongTensor of channel indices (e.g. from a descending
    importance sort). weight/bias are always permuted; running statistics
    exist only for BatchNorm and are permuted there as well.
    """
    bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
    bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
    if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
        bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


def copy_bn(target_bn, src_bn):
    """Copy the leading channels of ``src_bn`` into the (smaller) ``target_bn``.

    Used when extracting a sub-network: the sub-layer's norm receives the
    first ``feature_dim`` channels of the super-network's norm parameters.
    """
    feature_dim = (
        target_bn.num_channels
        if isinstance(target_bn, nn.GroupNorm)
        else target_bn.num_features
    )

    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
    if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
        target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])


class DynamicLinearLayer(MyModule):
    """Optional dropout + linear classifier head with an elastic input width."""

    def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
        super(DynamicLinearLayer, self).__init__()

        self.in_features_list = in_features_list
        self.out_features = out_features
        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None
        # weights sized for the widest input; narrower inputs use a slice
        self.linear = DynamicLinear(
            max_in_features=max(self.in_features_list),
            max_out_features=self.out_features,
            bias=self.bias,
        )

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return self.linear(x)

    @property
    def module_str(self):
        return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features)

    @property
    def config(self):
        # NOTE(review): reports DynamicLinear.__name__ rather than
        # DynamicLinearLayer.__name__, while build_from_config constructs a
        # DynamicLinearLayer. Looks like a wrong registry key — confirm what
        # set_layer_from_config expects before changing it.
        return {
            "name": DynamicLinear.__name__,
            "in_features_list": self.in_features_list,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicLinearLayer(**config)

    def get_active_subnet(self, in_features, preserve_weight=True):
        """Export a static LinearLayer for the given input width, optionally
        copying the corresponding weight slice from this dynamic layer."""
        sub_layer = LinearLayer(
            in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate
        )
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        sub_layer.linear.weight.data.copy_(
            self.linear.get_active_weight(self.out_features, in_features).data
        )
        if self.bias:
            sub_layer.linear.bias.data.copy_(
                self.linear.get_active_bias(self.out_features).data
            )
        return sub_layer

    def get_active_subnet_config(self, in_features):
        return {
            "name": LinearLayer.__name__,
            "in_features": in_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }


class DynamicMBConvLayer(MyModule):
    """MobileNet inverted-bottleneck block with elastic kernel size, expand
    ratio and output width.

    The ``active_*`` attributes select the currently sampled
    sub-configuration; forward() propagates them into the underlying dynamic
    ops before running.
    """

    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        kernel_size_list=3,
        expand_ratio_list=6,
        stride=1,
        act_func="relu6",
        use_se=False,
    ):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list)
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules — sized for the largest expand ratio / channel counts
        max_middle_channel = make_divisible(
            round(max(self.in_channel_list) * max(self.expand_ratio_list)),
            MyNetwork.CHANNEL_DIVISIBLE,
        )
        if max(self.expand_ratio_list) == 1:
            # expand ratio 1 → no pointwise expansion stage
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list), max_middle_channel
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max_middle_channel)),
                        ("act", build_activation(self.act_func)),
                    ]
                )
            )

        self.depth_conv = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicSeparableConv2d(
                            max_middle_channel, self.kernel_size_list, self.stride
                        ),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func)),
                ]
            )
        )
        if self.use_se:
            self.depth_conv.add_module("se", DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
                    ),
                    ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                ]
            )
        )

        # start with the largest sub-network active
        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        # push the sampled configuration down into the dynamic ops
        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = make_divisible(
                round(in_channel * self.active_expand_ratio),
                MyNetwork.CHANNEL_DIVISIBLE,
            )

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return "SE(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )
        else:
            return "(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )

    @property
    def config(self):
        return {
            "name": DynamicMBConvLayer.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "kernel_size_list": self.kernel_size_list,
            "expand_ratio_list": self.expand_ratio_list,
            "stride": self.stride,
            "act_func": self.act_func,
            "use_se": self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    def active_middle_channel(self, in_channel):
        # middle width for the current expand ratio, rounded to the
        # network-wide channel divisor
        return make_divisible(
            round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE
        )

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        """Export a static MBConvLayer for the active configuration,
        optionally copying the corresponding weight/BN slices."""
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        middle_channel = self.active_middle_channel(in_channel)
        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.get_active_filter(
                    middle_channel, in_channel
                ).data,
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(
                middle_channel, self.active_kernel_size
            ).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(
                middle_channel // SEModule.REDUCTION,
                divisor=MyNetwork.CHANNEL_DIVISIBLE,
            )
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
                self.depth_conv.se.get_active_reduce_bias(se_mid).data
            )

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
                self.depth_conv.se.get_active_expand_bias(middle_channel).data
            )

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.get_active_filter(
                self.active_out_channel, middle_channel
            ).data
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": MBConvLayer.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.active_kernel_size,
            "stride": self.stride,
            "expand_ratio": self.active_expand_ratio,
            "mid_channels": self.active_middle_channel(in_channel),
            "act_func": self.act_func,
            "use_se": self.use_se,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        """Sort middle channels by importance (L1 of the pointwise weights)
        so that smaller expand ratios keep the most important channels.

        With ``expand_ratio_stage > 0`` a large negative bias is added per
        width bucket so channels inside the smaller-ratio widths always sort
        ahead of the rest. All dependent tensors (depthwise conv, BN, SE,
        expansion conv) are permuted consistently, in place.
        """
        importance = torch.sum(
            torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.depth_conv.bn, DynamicGroupNorm):
            # GroupNorm: keep channels of a group together by averaging
            # importance within each group
            channel_per_group = self.depth_conv.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.in_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]

            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        # pointwise conv consumes middle channels on dim 1
        self.point_linear.conv.conv.weight.data = torch.index_select(
            self.point_linear.conv.conv.weight.data, 1, sorted_idx
        )

        adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
        # depthwise conv produces middle channels on dim 0
        self.depth_conv.conv.conv.weight.data = torch.index_select(
            self.depth_conv.conv.conv.weight.data, 0, sorted_idx
        )

        if self.use_se:
            # se expand: output dim 0 reorganize
            se_expand = self.depth_conv.se.fc.expand
            se_expand.weight.data = torch.index_select(
                se_expand.weight.data, 0, sorted_idx
            )
            se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
            # se reduce: input dim 1 reorganize
            se_reduce = self.depth_conv.se.fc.reduce
            se_reduce.weight.data = torch.index_select(
                se_reduce.weight.data, 1, sorted_idx
            )
            # middle weight reorganize
            se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
            se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

            se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
            se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

        if self.inverted_bottleneck is not None:
            adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
            self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
                self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
            )
            return None
        else:
            # no expansion stage: caller must propagate the permutation to the
            # producer of this block's input
            return sorted_idx


class DynamicConvLayer(MyModule):
    """Plain conv + optional BN + activation with an elastic output width."""

    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        kernel_size=3,
        stride=1,
        dilation=1,
        use_bn=True,
        act_func="relu6",
    ):
        super(DynamicConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.use_bn = use_bn
        self.act_func = act_func

        self.conv = DynamicConv2d(
            max_in_channels=max(self.in_channel_list),
            max_out_channels=max(self.out_channel_list),
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
        )
        if self.use_bn:
            self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
        self.act = build_activation(self.act_func)

        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        self.conv.active_out_channel = self.active_out_channel

        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.act(x)
        return x

    @property
    def module_str(self):
        return "DyConv(O%d, K%d, S%d)" % (
            self.active_out_channel,
            self.kernel_size,
            self.stride,
        )

    @property
    def config(self):
        return {
            "name": DynamicConvLayer.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "dilation": self.dilation,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        """Export a static ConvLayer for the active output width."""
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))

        if not preserve_weight:
            return sub_layer

        sub_layer.conv.weight.data.copy_(
            self.conv.get_active_filter(self.active_out_channel, in_channel).data
        )
        if self.use_bn:
            copy_bn(sub_layer.bn, self.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": ConvLayer.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "dilation": self.dilation,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
        }


class DynamicResNetBottleneckBlock(MyModule):
    """ResNet bottleneck (1x1 reduce → KxK → 1x1 expand + shortcut) with
    elastic expand ratio and output width.

    Unlike MBConv, the middle width here is derived from the OUTPUT channels
    (``active_out_channel * active_expand_ratio``).
    """

    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        expand_ratio_list=0.25,
        kernel_size=3,
        stride=1,
        act_func="relu",
        downsample_mode="avgpool_conv",
    ):
        super(DynamicResNetBottleneckBlock, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.kernel_size = kernel_size
        self.stride = stride
        self.act_func = act_func
        self.downsample_mode = downsample_mode

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.out_channel_list) * max(self.expand_ratio_list)),
            MyNetwork.CHANNEL_DIVISIBLE,
        )

        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max(self.in_channel_list), max_middle_channel),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(
                            max_middle_channel, max_middle_channel, kernel_size, stride
                        ),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        self.conv3 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
                    ),
                    ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                ]
            )
        )

        # shortcut path: identity when shape is preserved (note: compares the
        # channel LISTS, not single values), otherwise a projection
        if self.stride == 1 and self.in_channel_list == self.out_channel_list:
            self.downsample = IdentityLayer(
                max(self.in_channel_list), max(self.out_channel_list)
            )
        elif self.downsample_mode == "conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list),
                                max(self.out_channel_list),
                                stride=stride,
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                    ]
                )
            )
        elif self.downsample_mode == "avgpool_conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "avg_pool",
                            nn.AvgPool2d(
                                kernel_size=stride,
                                stride=stride,
                                padding=0,
                                ceil_mode=True,
                            ),
                        ),
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list), max(self.out_channel_list)
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                    ]
                )
            )
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        feature_dim = self.active_middle_channels

        # push the sampled widths into the dynamic convs
        self.conv1.conv.active_out_channel = feature_dim
        self.conv2.conv.active_out_channel = feature_dim
        self.conv3.conv.active_out_channel = self.active_out_channel
        if not isinstance(self.downsample, IdentityLayer):
            self.downsample.conv.active_out_channel = self.active_out_channel

        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
        x = self.final_act(x)
        return x

    @property
    def module_str(self):
        return "(%s, %s)" % (
            "%dx%d_BottleneckConv_in->%d->%d_S%d"
            % (
                self.kernel_size,
                self.kernel_size,
                self.active_middle_channels,
                self.active_out_channel,
                self.stride,
            ),
            "Identity"
            if isinstance(self.downsample, IdentityLayer)
            else self.downsample_mode,
        )

    @property
    def config(self):
        return {
            "name": DynamicResNetBottleneckBlock.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "expand_ratio_list": self.expand_ratio_list,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "act_func": self.act_func,
            "downsample_mode": self.downsample_mode,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicResNetBottleneckBlock(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    @property
    def active_middle_channels(self):
        # middle width derived from OUTPUT channels (ResNet convention)
        feature_dim = round(self.active_out_channel * self.active_expand_ratio)
        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        return feature_dim

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        """Export a static ResNetBottleneckBlock for the active configuration,
        optionally copying the corresponding weight/BN slices."""
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        # copy weight from current layer
        sub_layer.conv1.conv.weight.data.copy_(
            self.conv1.conv.get_active_filter(
                self.active_middle_channels, in_channel
            ).data
        )
        copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)

        sub_layer.conv2.conv.weight.data.copy_(
            self.conv2.conv.get_active_filter(
                self.active_middle_channels, self.active_middle_channels
            ).data
        )
        copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)

        sub_layer.conv3.conv.weight.data.copy_(
            self.conv3.conv.get_active_filter(
                self.active_out_channel, self.active_middle_channels
            ).data
        )
        copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)

        if not isinstance(self.downsample, IdentityLayer):
            sub_layer.downsample.conv.weight.data.copy_(
                self.downsample.conv.get_active_filter(
                    self.active_out_channel, in_channel
                ).data
            )
            copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": ResNetBottleneckBlock.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "expand_ratio": self.active_expand_ratio,
            "mid_channels": self.active_middle_channels,
            "act_func": self.act_func,
            "groups": 1,
            "downsample_mode": self.downsample_mode,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        """Sort both middle widths by importance so smaller expand ratios keep
        the most important channels; permutes conv1/conv2/conv3 and their BNs
        consistently, in place. Same bucketing scheme as DynamicMBConvLayer.
        """
        # conv3 -> conv2
        importance = torch.sum(
            torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.conv2.bn, DynamicGroupNorm):
            channel_per_group = self.conv2.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.out_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.conv3.conv.conv.weight.data = torch.index_select(
            self.conv3.conv.conv.weight.data, 1, sorted_idx
        )
        adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
        self.conv2.conv.conv.weight.data = torch.index_select(
            self.conv2.conv.conv.weight.data, 0, sorted_idx
        )

        # conv2 -> conv1
        importance = torch.sum(
            torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.conv1.bn, DynamicGroupNorm):
            channel_per_group = self.conv1.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.out_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left
        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)

        self.conv2.conv.conv.weight.data = torch.index_select(
            self.conv2.conv.conv.weight.data, 1, sorted_idx
        )
        adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
        self.conv1.conv.conv.weight.data = torch.index_select(
            self.conv1.conv.conv.weight.data, 0, sorted_idx
        )

        return None
diff --git a/proard/classification/elastic_nn/modules/dynamic_op.py b/proard/classification/elastic_nn/modules/dynamic_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e415b16c0e9bace804db6ff6d068f0abb5cc63e3
--- /dev/null
+++ b/proard/classification/elastic_nn/modules/dynamic_op.py
@@ -0,0 +1,401 @@
+# Once for All: Train One Network and Specialize it for Efficient Deployment
+# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
+# International Conference on Learning Representations (ICLR), 2020.
# dynamic_op.py — elastic primitive operators. Each op holds the weights of
# its LARGEST configuration and slices them at forward time according to the
# currently active kernel size / channel count.

import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter

from proard.utils import (
    get_same_padding,
    sub_filter_start_end,
    make_divisible,
    SEModule,
    MyNetwork,
    MyConv2d,
)

__all__ = [
    "DynamicSeparableConv2d",
    "DynamicConv2d",
    "DynamicGroupConv2d",
    "DynamicBatchNorm2d",
    "DynamicGroupNorm",
    "DynamicSE",
    "DynamicLinear",
]

# Separable conv consists of a depthwise and pointwise conv


class DynamicSeparableConv2d(nn.Module):
    """Depthwise conv with an elastic kernel size.

    Weights are stored for the largest kernel; a smaller active kernel uses
    the center crop, optionally passed through learned per-step transform
    matrices (elastic-kernel trick from OFA).
    """

    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list  # list of kernel size
        self.stride = stride
        self.dilation = dilation

        # groups == channels → depthwise; padding is applied in forward()
        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_in_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=self.max_in_channels,
            bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        # define matrices for transforming a larger kernel's center crop into
        # the next smaller kernel size
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = "%dto%d" % (ks_larger, ks_small)
                # noinspection PyArgumentList
                scale_params["%s_matrix" % param_name] = Parameter(
                    torch.eye(ks_small ** 2)  # initialized to identity
                )
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        """Return the filter slice for the requested channels/kernel size,
        applying the learned kernel-transform chain when enabled."""
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[
                :out_channel, :in_channel, :, :
            ]  # start with max kernel
            # walk down the kernel-size ladder, transforming step by step
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(
                    _input_filter.size(0), _input_filter.size(1), -1
                )
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter,
                    self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks ** 2
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks, target_ks
                )
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        # groups=in_channel keeps the conv depthwise for the active width
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
        return y


class DynamicConv2d(nn.Module):
    """Conv with elastic input/output channel counts (leading-slice weights)."""

    def __init__(
        self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
    ):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        # padding is applied in forward()
        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_out_channels,
            self.kernel_size,
            stride=self.stride,
            bias=False,
        )

        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        padding = get_same_padding(self.kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
        return y


class DynamicGroupConv2d(nn.Module):
    """Grouped conv with elastic kernel size AND group count; channel totals
    are fixed. Weights are allocated for the smallest group count (largest
    per-group filter) and cropped per active group."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size_list,
        groups_list,
        stride=1,
        dilation=1,
    ):
        super(DynamicGroupConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=min(self.groups_list),
            bias=False,
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        """Center-crop to ``kernel_size`` and carve each output chunk's
        per-group input slice for the requested group count."""
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        filters = self.conv.weight[:, :, start:end, start:end]

        sub_filters = torch.chunk(filters, groups, dim=0)
        sub_in_channels = self.in_channels // groups
        sub_ratio = filters.size(1) // sub_in_channels

        filter_crops = []
        # NOTE(review): each output chunk i reads input slice (i % sub_ratio)
        # of the stored per-group width — presumably to tile the stored
        # (wider) group filters across the finer grouping; confirm against
        # the elastic-group training scheme.
        for i, sub_filter in enumerate(sub_filters):
            part_id = i % sub_ratio
            start = part_id * sub_in_channels
            filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :])
        filters = torch.cat(filter_crops, dim=0)
        return filters

    def forward(self, x, kernel_size=None, groups=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        if groups is None:
            groups = self.active_groups

        filters = self.get_active_filter(kernel_size, groups).contiguous()
        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(
            x,
            filters,
            None,
            self.stride,
            padding,
            self.dilation,
            groups,
        )
        return y


class DynamicBatchNorm2d(nn.Module):
    """BatchNorm2d applied to a leading slice of its feature dimension.

    When the input is narrower than ``max_feature_dim``, re-implements the
    nn.BatchNorm2d forward via F.batch_norm on sliced parameters/statistics
    (including the cumulative-average momentum behavior).
    """

    # when True (used while recalibrating BN stats), always run the plain
    # nn.BatchNorm2d forward so running stats update for the full width
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()

        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        else:
            exponential_average_factor = 0.0

            if bn.training and bn.track_running_stats:
                if bn.num_batches_tracked is not None:
                    bn.num_batches_tracked += 1
                if bn.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = bn.momentum
            return F.batch_norm(
                x,
                bn.running_mean[:feature_dim],
                bn.running_var[:feature_dim],
                bn.weight[:feature_dim],
                bn.bias[:feature_dim],
                bn.training or not bn.track_running_stats,
                exponential_average_factor,
                bn.eps,
            )

    def forward(self, x):
        feature_dim = x.size(1)
        y = self.bn_forward(x, self.bn, feature_dim)
        return y


class DynamicGroupNorm(nn.GroupNorm):
    """GroupNorm with a fixed channels-per-group; the group count adapts to
    the (possibly narrower) input width."""

    def __init__(
        self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
    ):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        self.channel_per_group = channel_per_group

    def forward(self, x):
        n_channels = x.size(1)
        n_groups = n_channels // self.channel_per_group
        return F.group_norm(
            x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
        )

    @property
    def bn(self):
        # mimic DynamicBatchNorm2d's `.bn` attribute so callers can treat
        # both norm wrappers uniformly
        return self


class DynamicSE(SEModule):
    """Squeeze-and-excitation module whose reduce/expand convs are sliced to
    the active channel count (with optional grouped slicing)."""

    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1
            )
            return torch.cat(
                [sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters],
                dim=1,
            )

    def get_active_reduce_bias(self, num_mid):
        return (
            self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
        )

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.weight[:in_channel, :num_mid, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0
            )
            return torch.cat(
                [sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters],
                dim=0,
            )

    def get_active_expand_bias(self, in_channel, groups=None):
        if groups is None or groups == 1:
            return (
                self.fc.expand.bias[:in_channel]
                if self.fc.expand.bias is not None
                else None
            )
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
            return torch.cat(
                [sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0
            )

    def forward(self, x, groups=None):
        in_channel = x.size(1)
        # self.reduction is inherited from SEModule
        num_mid = make_divisible(
            in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )

        # squeeze: global average pool over H and W
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # reduce
        reduce_filter = self.get_active_reduce_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        reduce_bias = self.get_active_reduce_bias(num_mid)
        y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
        # relu
        y = self.fc.relu(y)
        # expand
        expand_filter = self.get_active_expand_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
        y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
        # hard sigmoid
        y = self.fc.h_sigmoid(y)

        return x * y


class DynamicLinear(nn.Module):
    """Linear layer sliced to the active input/output feature counts."""

    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()

        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias

        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        return self.linear.bias[:out_features] if self.bias else None

    def forward(self, x, out_features=None):
        if out_features is None:
            out_features = self.active_out_features

        in_features = x.size(1)
        weight = self.get_active_weight(out_features, in_features).contiguous()
        bias = self.get_active_bias(out_features)
        y = F.linear(x, weight, bias)
        return y
diff --git a/proard/classification/elastic_nn/networks/__init__.py b/proard/classification/elastic_nn/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..afc2aa6dab164ccde7285593e9d2aa245c21e8d0
--- /dev/null
+++ b/proard/classification/elastic_nn/networks/__init__.py
@@ -0,0 +1,7 @@
+# Once for All: Train One Network and Specialize it for Efficient Deployment
+# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
+# International Conference on Learning Representations (ICLR), 2020.
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from proard.classification.elastic_nn.modules.dynamic_layers import (
    DynamicMBConvLayer,
)
from proard.utils.layers import (
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    MBConvLayer,
    ResidualBlock,
)
from proard.classification.networks import MobileNetV3, MobileNetV3_Cifar
from proard.utils import make_divisible, val2list, MyNetwork

__all__ = ["DYNMobileNetV3", "DYNMobileNetV3_Cifar"]


class _DYNMobileNetV3Mixin:
    """Shared elastic-supernet behaviour for the ImageNet and CIFAR variants.

    The two public classes below previously duplicated ~350 lines each; the
    only real differences are the first-conv stride, the per-stage strides,
    and the default number of classes, all of which are passed into
    :meth:`_build_dyn_mbv3` by each subclass ``__init__``.
    """

    # Static network class built by get_active_subnet / named in configs.
    STATIC_NET_CLS = None

    def _build_dyn_mbv3(
        self,
        n_classes,
        bn_param,
        dropout_rate,
        width_mult,
        ks_list,
        expand_ratio_list,
        depth_list,
        first_conv_stride,
        stride_stages,
    ):
        """Construct all sub-modules and the elastic bookkeeping state.

        Attribute assignments deliberately precede the base-class ``__init__``
        call, mirroring the original construction order.
        """
        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        # Fixed MobileNetV3-Large stage widths (before width_mult scaling).
        base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

        final_expand_width = make_divisible(
            base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
        se_stages = [False, False, True, False, True, True]
        # One fixed first block, then five elastic stages of max depth each.
        n_block_list = [1] + [max(self.depth_list)] * 5
        width_list = [
            make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for base_width in base_stage_width[:-2]
        ]

        input_channel, first_block_dim = width_list[0], width_list[1]
        # First conv layer (stride differs between ImageNet and CIFAR inputs).
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=first_conv_stride, act_func="h_swish"
        )
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_dim,
            kernel_size=3,
            stride=stride_stages[0],
            expand_ratio=1,
            act_func=act_stages[0],
            use_se=se_stages[0],
        )
        first_block = ResidualBlock(
            first_block_conv,
            IdentityLayer(first_block_dim, first_block_dim)
            if input_channel == first_block_dim
            else None,
        )

        # Inverted residual blocks, grouped per stage for depth elasticity.
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1
        feature_dim = first_block_dim

        for width, n_block, s, act_func, use_se in zip(
            width_list[2:],
            n_block_list[1:],
            stride_stages[1:],
            act_stages[1:],
            se_stages[1:],
        ):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                stride = s if i == 0 else 1
                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(feature_dim),
                    out_channel_list=val2list(output_channel),
                    # Use the normalized, sorted lists rather than the raw
                    # constructor arguments (consistency fix).
                    kernel_size_list=self.ks_list,
                    expand_ratio_list=self.expand_ratio_list,
                    stride=stride,
                    act_func=act_func,
                    use_se=use_se,
                )
                if stride == 1 and feature_dim == output_channel:
                    shortcut = IdentityLayer(feature_dim, feature_dim)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                feature_dim = output_channel

        # Final expand layer, feature mix layer & classifier.
        final_expand_layer = ConvLayer(
            feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
        )
        feature_mix_layer = ConvLayer(
            final_expand_width,
            last_channel,
            kernel_size=1,
            bias=False,
            use_bn=False,
            act_func="h_swish",
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        # Resolves to MobileNetV3/MobileNetV3_Cifar via the subclass MRO.
        super().__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )

        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # Active depth per stage; starts at the maximum (full supernet).
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    def forward(self, x):
        x = self.first_conv(x)
        x = self.blocks[0](x)
        # Elastic stages: only the first runtime_depth[stage] blocks execute.
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            for idx in block_idx[:depth]:
                x = self.blocks[idx](x)
        x = self.final_expand_layer(x)
        x = x.mean(3, keepdim=True).mean(2, keepdim=True)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        """Human-readable summary of the currently active sub-network."""
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            for idx in block_idx[:depth]:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.final_expand_layer.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": type(self).__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "final_expand_layer": self.final_expand_layer.config,
            "feature_mix_layer": self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, remapping keys between static and dynamic naming.

        Accepts checkpoints saved either from the static networks (e.g.
        ``.mobile_inverted_conv.``, ``.bn.``) or from the dynamic supernet
        (``.bn.bn.``, ``.conv.conv.weight``, ``.linear.linear.``) and rewrites
        each key to match this model's state dict; unknown keys raise.
        """
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super().load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        """Activate the largest sub-network (max kernel/expand/depth)."""
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        """Select the active sub-network; None leaves a dimension unchanged."""
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        """Restrict the candidate pool used by sample_active_subnet."""
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        """Uniformly sample and activate a sub-network; returns its setting."""
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # Sample kernel size per block.
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            ks_setting.append(random.choice(k_set))

        # Sample expand ratio per block.
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            expand_setting.append(random.choice(e_set))

        # Sample depth per stage.
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            depth_setting.append(random.choice(d_set))

        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        """Extract the active sub-network as a standalone static model."""
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]

        final_expand_layer = copy.deepcopy(self.final_expand_layer)
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            stage_blocks = []
            for idx in block_idx[:depth]:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = self.STATIC_NET_CLS(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        """Return the config dict describing the active sub-network."""
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        final_expand_config = self.final_expand_layer.config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            stage_blocks = []
            for idx in block_idx[:depth]:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
                        else None,
                    }
                )
                input_channel = self.blocks[idx].conv.active_out_channel
            block_config_list += stage_blocks

        return {
            "name": self.STATIC_NET_CLS.__name__,
            "bn": self.get_bn_param(),
            "first_conv": first_conv_config,
            "blocks": block_config_list,
            "final_expand_layer": final_expand_config,
            "feature_mix_layer": feature_mix_layer_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        """Sort each dynamic block's middle channels by importance."""
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)


class DYNMobileNetV3(_DYNMobileNetV3Mixin, MobileNetV3):
    """Elastic MobileNetV3 supernet for ImageNet-sized inputs."""

    STATIC_NET_CLS = MobileNetV3

    def __init__(
        self,
        n_classes=1000,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.1,
        base_stage_width=None,  # accepted for interface compatibility; ignored
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):
        self._build_dyn_mbv3(
            n_classes=n_classes,
            bn_param=bn_param,
            dropout_rate=dropout_rate,
            width_mult=width_mult,
            ks_list=ks_list,
            expand_ratio_list=expand_ratio_list,
            depth_list=depth_list,
            first_conv_stride=2,
            stride_stages=[1, 2, 2, 2, 1, 2],
        )

    @staticmethod
    def name():
        return "DYNMobileNetV3"


class DYNMobileNetV3_Cifar(_DYNMobileNetV3Mixin, MobileNetV3_Cifar):
    """Elastic MobileNetV3 supernet for CIFAR-sized (32x32) inputs.

    Differs from DYNMobileNetV3 only in the stride-1 stem and the reduced
    down-sampling schedule, preserving spatial resolution on small images.
    """

    STATIC_NET_CLS = MobileNetV3_Cifar

    def __init__(
        self,
        n_classes=10,
        bn_param=(0.1, 1e-5),
        dropout_rate=0.1,
        base_stage_width=None,  # accepted for interface compatibility; ignored
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):
        self._build_dyn_mbv3(
            n_classes=n_classes,
            bn_param=bn_param,
            dropout_rate=dropout_rate,
            width_mult=width_mult,
            ks_list=ks_list,
            expand_ratio_list=expand_ratio_list,
            depth_list=depth_list,
            first_conv_stride=1,
            stride_stages=[1, 1, 2, 2, 1, 2],
        )

    @staticmethod
    def name():
        return "DYNMobileNetV3_Cifar"
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from proard.utils import make_divisible, val2list, MyNetwork
from proard.classification.elastic_nn.modules import DynamicMBConvLayer
from proard.utils.layers import (
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    MBConvLayer,
    ResidualBlock,
)
from proard.classification.networks.proxyless_nets import (
    ProxylessNASNets,
    ProxylessNASNets_Cifar,
)

__all__ = ["DYNProxylessNASNets", "DYNProxylessNASNets_Cifar"]


class DYNProxylessNASNets(ProxylessNASNets):
    """Elastic ProxylessNAS supernet with adjustable kernel/expand/depth.

    Sub-networks are selected at run time via set_active_subnet /
    sample_active_subnet and can be extracted as static ProxylessNASNets
    models with get_active_subnet.
    """

    def __init__(
        self,
        n_classes=1000,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.1,
        base_stage_width=None,
        width_mult=1.0,
        ks_list=3,
        expand_ratio_list=6,
        depth_list=4,
    ):
        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        # NOTE(review): the CIFAR variant keys this switch on "MBV2" instead
        # of "google" — confirm which spelling callers actually pass.
        if base_stage_width == "google":
            # MobileNetV2 stage widths
            base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
        else:
            # ProxylessNAS stage widths
            base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

        input_channel = make_divisible(
            base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        first_block_width = make_divisible(
            base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = make_divisible(
            base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )

        # First conv layer (stride 2: ImageNet-sized inputs).
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=2,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # First block: fixed, expand ratio 1, no shortcut.
        first_block_conv = MBConvLayer(
            in_channels=input_channel,
            out_channels=first_block_width,
            kernel_size=3,
            stride=1,
            expand_ratio=1,
            act_func="relu6",
        )
        first_block = ResidualBlock(first_block_conv, None)

        input_channel = first_block_width
        # Inverted residual blocks, grouped per stage for depth elasticity.
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1

        stride_stages = [2, 2, 2, 1, 2, 1]
        # Five elastic stages plus one fixed single-block final stage.
        n_block_list = [max(self.depth_list)] * 5 + [1]

        width_list = [
            make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for base_width in base_stage_width[2:-1]
        ]

        for width, n_block, s in zip(width_list, n_block_list, stride_stages):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                stride = s if i == 0 else 1

                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(input_channel, 1),
                    out_channel_list=val2list(output_channel, 1),
                    # Use the normalized, sorted lists rather than the raw
                    # constructor arguments (consistency fix).
                    kernel_size_list=self.ks_list,
                    expand_ratio_list=self.expand_ratio_list,
                    stride=stride,
                    act_func="relu6",
                )

                if stride == 1 and input_channel == output_channel:
                    shortcut = IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None

                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                input_channel = output_channel

        # 1x1 conv before global average pooling.
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            use_bn=True,
            act_func="relu6",
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(DYNProxylessNASNets, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # Active depth per stage; starts at the maximum (full supernet).
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return "DYNProxylessNASNets"

    def forward(self, x):
        x = self.first_conv(x)
        x = self.blocks[0](x)

        # Elastic stages: only the first runtime_depth[stage] blocks execute.
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            for idx in block_idx[:depth]:
                x = self.blocks[idx](x)

        x = self.feature_mix_layer(x)
        x = x.mean(3).mean(2)  # global average pooling

        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        """Human-readable summary of the currently active sub-network."""
        _str = self.first_conv.module_str + "\n"
        _str += self.blocks[0].module_str + "\n"

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            for idx in block_idx[:depth]:
                _str += self.blocks[idx].module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.classifier.module_str + "\n"
        return _str

    @property
    def config(self):
        return {
            "name": DYNProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError("do not support this function")

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, remapping keys between static and dynamic naming.

        Accepts checkpoints saved either from the static networks or from
        the dynamic supernet and rewrites each key to match this model's
        state dict; unknown keys raise ValueError.
        """
        model_dict = self.state_dict()
        for key in state_dict:
            if ".mobile_inverted_conv." in key:
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif ".bn.bn." in new_key:
                new_key = new_key.replace(".bn.bn.", ".bn.")
            elif ".conv.conv.weight" in new_key:
                new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
            elif ".linear.linear." in new_key:
                new_key = new_key.replace(".linear.linear.", ".linear.")
            ##############################################################################
            elif ".linear." in new_key:
                new_key = new_key.replace(".linear.", ".linear.linear.")
            elif "bn." in new_key:
                new_key = new_key.replace("bn.", "bn.bn.")
            elif "conv.weight" in new_key:
                new_key = new_key.replace("conv.weight", "conv.conv.weight")
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, "%s" % new_key
            model_dict[new_key] = state_dict[key]
        super(DYNProxylessNASNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        """Activate the largest sub-network (max kernel/expand/depth)."""
        self.set_active_subnet(
            ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
        )

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        """Select the active sub-network; None leaves a dimension unchanged."""
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type="depth"):
        """Restrict the candidate pool used by sample_active_subnet."""
        if constraint_type == "depth":
            self.__dict__["_depth_include_list"] = include_list.copy()
        elif constraint_type == "expand_ratio":
            self.__dict__["_expand_include_list"] = include_list.copy()
        elif constraint_type == "kernel_size":
            self.__dict__["_ks_include_list"] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__["_depth_include_list"] = None
        self.__dict__["_expand_include_list"] = None
        self.__dict__["_ks_include_list"] = None

    def sample_active_subnet(self):
        """Uniformly sample and activate a sub-network; returns its setting."""
        ks_candidates = (
            self.ks_list
            if self.__dict__.get("_ks_include_list", None) is None
            else self.__dict__["_ks_include_list"]
        )
        expand_candidates = (
            self.expand_ratio_list
            if self.__dict__.get("_expand_include_list", None) is None
            else self.__dict__["_expand_include_list"]
        )
        depth_candidates = (
            self.depth_list
            if self.__dict__.get("_depth_include_list", None) is None
            else self.__dict__["_depth_include_list"]
        )

        # Sample kernel size per block.
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            ks_setting.append(random.choice(k_set))

        # Sample expand ratio per block.
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            expand_setting.append(random.choice(e_set))

        # Sample depth per stage.
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [
                depth_candidates for _ in range(len(self.block_group_info))
            ]
        for d_set in depth_candidates:
            depth_setting.append(random.choice(d_set))

        # The last stage is built with a single block (n_block_list ends in 1),
        # so its depth is pinned to 1 regardless of the sampled value.
        depth_setting[-1] = 1
        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            "ks": ks_setting,
            "e": expand_setting,
            "d": depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        """Extract the active sub-network as a standalone static model."""
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            stage_blocks = []
            for idx in block_idx[:depth]:
                stage_blocks.append(
                    ResidualBlock(
                        self.blocks[idx].conv.get_active_subnet(
                            input_channel, preserve_weight
                        ),
                        copy.deepcopy(self.blocks[idx].shortcut),
                    )
                )
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        """Return the config dict describing the active sub-network."""
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config["conv"]["out_channels"]
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            stage_blocks = []
            for idx in block_idx[:depth]:
                stage_blocks.append(
                    {
                        "name": ResidualBlock.__name__,
                        "conv": self.blocks[idx].conv.get_active_subnet_config(
                            input_channel
                        ),
                        "shortcut": self.blocks[idx].shortcut.config
                        if self.blocks[idx].shortcut is not None
                        else None,
                    }
                )
                # Dynamic blocks expose active_out_channel; static layers
                # (e.g. the fixed first block) only have out_channels.
                try:
                    input_channel = self.blocks[idx].conv.active_out_channel
                except AttributeError:
                    input_channel = self.blocks[idx].conv.out_channels
            block_config_list += stage_blocks

        return {
            "name": ProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": first_conv_config,
            "blocks": block_config_list,
            "feature_mix_layer": feature_mix_layer_config,
            "classifier": classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        """Sort each dynamic block's middle channels by importance."""
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
16, 24, 40, 80, 96, 192, 320, 1280] + + input_channel = make_divisible( + base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + first_block_width = make_divisible( + base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + last_channel = make_divisible( + base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=1, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + # first block + first_block_conv = MBConvLayer( + in_channels=input_channel, + out_channels=first_block_width, + kernel_size=3, + stride=1, + expand_ratio=1, + act_func="relu6", + ) + first_block = ResidualBlock(first_block_conv, None) + + input_channel = first_block_width + # inverted residual blocks + self.block_group_info = [] + blocks = [first_block] + _block_index = 1 + + stride_stages = [1, 2, 2, 1, 2, 1] + n_block_list = [max(self.depth_list)] * 5 + [1] + + width_list = [] + for base_width in base_stage_width[2:-1]: + width = make_divisible( + base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + width_list.append(width) + + for width, n_block, s in zip(width_list, n_block_list, stride_stages): + self.block_group_info.append([_block_index + i for i in range(n_block)]) + _block_index += n_block + + output_channel = width + for i in range(n_block): + if i == 0: + stride = s + else: + stride = 1 + + mobile_inverted_conv = DynamicMBConvLayer( + in_channel_list=val2list(input_channel, 1), + out_channel_list=val2list(output_channel, 1), + kernel_size_list=ks_list, + expand_ratio_list=expand_ratio_list, + stride=stride, + act_func="relu6", + ) + + if stride == 1 and input_channel == output_channel: + shortcut = IdentityLayer(input_channel, input_channel) + else: + shortcut = None + + mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut) + + blocks.append(mb_inverted_block) + input_channel = output_channel + # 1x1_conv 
before global average pooling + feature_mix_layer = ConvLayer( + input_channel, + last_channel, + kernel_size=1, + use_bn=True, + act_func="relu6", + ) + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + super(DYNProxylessNASNets_Cifar, self).__init__( + first_conv, blocks, feature_mix_layer, classifier + ) + + # set bn param + self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) + + # runtime_depth + self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info] + + """ MyNetwork required methods """ + + @staticmethod + def name(): + return "DYNProxylessNASNets_Cifar" + + def forward(self, x): + # first conv + x = self.first_conv(x) + # first block + x = self.blocks[0](x) + + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + x = self.blocks[idx](x) + + # feature_mix_layer + x = self.feature_mix_layer(x) + x = x.mean(3).mean(2) + + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + _str += self.blocks[0].module_str + "\n" + + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + "\n" + return _str + + @property + def config(self): + return { + "name": DYNProxylessNASNets_Cifar.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "feature_mix_layer": None + if self.feature_mix_layer is None + else self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + @property + def grouped_block_index(self): + 
return self.block_group_info + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + if ".mobile_inverted_conv." in key: + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + if new_key in model_dict: + pass + elif ".bn.bn." in new_key: + new_key = new_key.replace(".bn.bn.", ".bn.") + elif ".conv.conv.weight" in new_key: + new_key = new_key.replace(".conv.conv.weight", ".conv.weight") + elif ".linear.linear." in new_key: + new_key = new_key.replace(".linear.linear.", ".linear.") + ############################################################################## + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(DYNProxylessNASNets_Cifar, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list) + ) + + def set_active_subnet(self, ks=None, e=None, d=None, **kwargs): + ks = val2list(ks, len(self.blocks) - 1) + expand_ratio = val2list(e, len(self.blocks) - 1) + depth = val2list(d, len(self.block_group_info)) + + for block, k, e in zip(self.blocks[1:], ks, expand_ratio): + if k is not None: + block.conv.active_kernel_size = k + if e is not None: + block.conv.active_expand_ratio = e + + for i, d in enumerate(depth): + if d is not None: + self.runtime_depth[i] = min(len(self.block_group_info[i]), d) + + def set_constraint(self, include_list, constraint_type="depth"): + if constraint_type == "depth": + self.__dict__["_depth_include_list"] = include_list.copy() + elif constraint_type == "expand_ratio": + 
self.__dict__["_expand_include_list"] = include_list.copy() + elif constraint_type == "kernel_size": + self.__dict__["_ks_include_list"] = include_list.copy() + else: + raise NotImplementedError + + def clear_constraint(self): + self.__dict__["_depth_include_list"] = None + self.__dict__["_expand_include_list"] = None + self.__dict__["_ks_include_list"] = None + + def sample_active_subnet(self): + ks_candidates = ( + self.ks_list + if self.__dict__.get("_ks_include_list", None) is None + else self.__dict__["_ks_include_list"] + ) + expand_candidates = ( + self.expand_ratio_list + if self.__dict__.get("_expand_include_list", None) is None + else self.__dict__["_expand_include_list"] + ) + depth_candidates = ( + self.depth_list + if self.__dict__.get("_depth_include_list", None) is None + else self.__dict__["_depth_include_list"] + ) + + # sample kernel size + ks_setting = [] + if not isinstance(ks_candidates[0], list): + ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)] + for k_set in ks_candidates: + k = random.choice(k_set) + ks_setting.append(k) + + # sample expand ratio + expand_setting = [] + if not isinstance(expand_candidates[0], list): + expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)] + for e_set in expand_candidates: + e = random.choice(e_set) + expand_setting.append(e) + + # sample depth + depth_setting = [] + if not isinstance(depth_candidates[0], list): + depth_candidates = [ + depth_candidates for _ in range(len(self.block_group_info)) + ] + for d_set in depth_candidates: + d = random.choice(d_set) + depth_setting.append(d) + + depth_setting[-1] = 1 + self.set_active_subnet(ks_setting, expand_setting, depth_setting) + + return { + "ks": ks_setting, + "e": expand_setting, + "d": depth_setting, + } + + def get_active_subnet(self, preserve_weight=True): + first_conv = copy.deepcopy(self.first_conv) + blocks = [copy.deepcopy(self.blocks[0])] + feature_mix_layer = copy.deepcopy(self.feature_mix_layer) + 
classifier = copy.deepcopy(self.classifier) + + input_channel = blocks[0].conv.out_channels + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + ResidualBlock( + self.blocks[idx].conv.get_active_subnet( + input_channel, preserve_weight + ), + copy.deepcopy(self.blocks[idx].shortcut), + ) + ) + input_channel = stage_blocks[-1].conv.out_channels + blocks += stage_blocks + + _subnet = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier) + _subnet.set_bn_param(**self.get_bn_param()) + return _subnet + + def get_active_net_config(self): + first_conv_config = self.first_conv.config + first_block_config = self.blocks[0].config + feature_mix_layer_config = self.feature_mix_layer.config + classifier_config = self.classifier.config + + block_config_list = [first_block_config] + input_channel = first_block_config["conv"]["out_channels"] + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + { + "name": ResidualBlock.__name__, + "conv": self.blocks[idx].conv.get_active_subnet_config( + input_channel + ), + "shortcut": self.blocks[idx].shortcut.config + if self.blocks[idx].shortcut is not None + else None, + } + ) + try: + input_channel = self.blocks[idx].conv.active_out_channel + except Exception: + input_channel = self.blocks[idx].conv.out_channels + block_config_list += stage_blocks + + return { + "name": ProxylessNASNets_Cifar.__name__, + "bn": self.get_bn_param(), + "first_conv": first_conv_config, + "blocks": block_config_list, + "feature_mix_layer": feature_mix_layer_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks[1:]: + 
block.conv.re_organize_middle_weights(expand_ratio_stage) diff --git a/proard/classification/elastic_nn/networks/dyn_resnets.py b/proard/classification/elastic_nn/networks/dyn_resnets.py new file mode 100644 index 0000000000000000000000000000000000000000..90e91dbd81507e50e6868f3377e1aabb563704b3 --- /dev/null +++ b/proard/classification/elastic_nn/networks/dyn_resnets.py @@ -0,0 +1,678 @@ +import random + +from proard.classification.elastic_nn.modules.dynamic_layers import ( + DynamicConvLayer, + DynamicLinearLayer, +) +from proard.classification.elastic_nn.modules.dynamic_layers import ( + DynamicResNetBottleneckBlock, +) +from proard.utils.layers import IdentityLayer, ResidualBlock +from proard.classification.networks import ResNets,ResNets_Cifar +from proard.utils import make_divisible, val2list, MyNetwork + +__all__ = ["DYNResNets","DYNResNets_Cifar"] + + +class DYNResNets(ResNets): + def __init__( + self, + n_classes=1000, + bn_param=(0.1, 1e-5), + dropout_rate=0, + depth_list=2, + expand_ratio_list=0.25, + width_mult_list=1.0, + ): + + self.depth_list = val2list(depth_list) + self.expand_ratio_list = val2list(expand_ratio_list) + self.width_mult_list = val2list(width_mult_list) + # sort + self.depth_list.sort() + self.expand_ratio_list.sort() + self.width_mult_list.sort() + + input_channel = [ + make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + for width_mult in self.width_mult_list + ] + mid_input_channel = [ + make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) + for channel in input_channel + ] + + stage_width_list = ResNets.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = [ + make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + for width_mult in self.width_mult_list + ] + + n_block_list = [ + base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST + ] + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + DynamicConvLayer( + val2list(3), + 
mid_input_channel, + 3, + stride=2, + use_bn=True, + act_func="relu", + ), + ResidualBlock( + DynamicConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + DynamicConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(n_block_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = DynamicResNetBottleneckBlock( + input_channel, + width, + expand_ratio_list=self.expand_ratio_list, + kernel_size=3, + stride=stride, + act_func="relu", + downsample_mode="avgpool_conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = DynamicLinearLayer( + input_channel, n_classes, dropout_rate=dropout_rate + ) + + super(DYNResNets, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + # runtime_depth + self.input_stem_skipping = 0 + self.runtime_depth = [0] * len(n_block_list) + + @property + def ks_list(self): + return [3] + + @staticmethod + def name(): + return "DYNResNets" + + def forward(self, x): + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + x = layer(x) + x = self.max_pooling(x) + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + x = self.blocks[idx](x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = "" + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + _str += 
layer.module_str + "\n" + _str += "max_pooling(ks=3, stride=2)\n" + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": DYNResNets.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + new_key = key + if new_key in model_dict: + pass + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." 
in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(DYNResNets, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + d=max(self.depth_list), + e=max(self.expand_ratio_list), + w=len(self.width_mult_list) - 1, + ) + + def set_active_subnet(self, d=None, e=None, w=None, **kwargs): + depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1) + expand_ratio = val2list(e, len(self.blocks)) + width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2) + + for block, e in zip(self.blocks, expand_ratio): + if e is not None: + block.active_expand_ratio = e + + if width_mult[0] is not None: + self.input_stem[1].conv.active_out_channel = self.input_stem[ + 0 + ].active_out_channel = self.input_stem[0].out_channel_list[width_mult[0]] + if width_mult[1] is not None: + self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[ + width_mult[1] + ] + + if depth[0] is not None: + self.input_stem_skipping = depth[0] != max(self.depth_list) + for stage_id, (block_idx, d, w) in enumerate( + zip(self.grouped_block_index, depth[1:], width_mult[2:]) + ): + if d is not None: + self.runtime_depth[stage_id] = max(self.depth_list) - d + if w is not None: + for idx in block_idx: + self.blocks[idx].active_out_channel = self.blocks[ + idx + ].out_channel_list[w] + + def sample_active_subnet(self): + # sample expand ratio + expand_setting = [] + for block in self.blocks: + expand_setting.append(random.choice(block.expand_ratio_list)) + + # sample depth + depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])] + for stage_id in range(len(ResNets.BASE_DEPTH_LIST)): + depth_setting.append(random.choice(self.depth_list)) + + # sample width_mult + 
width_mult_setting = [ + random.choice(list(range(len(self.input_stem[0].out_channel_list)))), + random.choice(list(range(len(self.input_stem[2].out_channel_list)))), + ] + for stage_id, block_idx in enumerate(self.grouped_block_index): + stage_first_block = self.blocks[block_idx[0]] + width_mult_setting.append( + random.choice(list(range(len(stage_first_block.out_channel_list)))) + ) + + arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting} + self.set_active_subnet(**arch_config) + return arch_config + + def get_active_subnet(self, preserve_weight=True): + input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)] + if self.input_stem_skipping <= 0: + input_stem.append( + ResidualBlock( + self.input_stem[1].conv.get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ), + IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + ) + ) + input_stem.append( + self.input_stem[2].get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + blocks.append( + self.blocks[idx].get_active_subnet(input_channel, preserve_weight) + ) + input_channel = self.blocks[idx].active_out_channel + classifier = self.classifier.get_active_subnet(input_channel, preserve_weight) + subnet = ResNets(input_stem, blocks, classifier) + + subnet.set_bn_param(**self.get_bn_param()) + return subnet + + def get_active_net_config(self): + input_stem_config = [self.input_stem[0].get_active_subnet_config(3)] + if self.input_stem_skipping <= 0: + input_stem_config.append( + { + "name": ResidualBlock.__name__, + "conv": self.input_stem[1].conv.get_active_subnet_config( + self.input_stem[0].active_out_channel + 
), + "shortcut": IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + } + ) + input_stem_config.append( + self.input_stem[2].get_active_subnet_config( + self.input_stem[0].active_out_channel + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks_config = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + blocks_config.append( + self.blocks[idx].get_active_subnet_config(input_channel) + ) + input_channel = self.blocks[idx].active_out_channel + classifier_config = self.classifier.get_active_subnet_config(input_channel) + return { + "name": ResNets.__name__, + "bn": self.get_bn_param(), + "input_stem": input_stem_config, + "blocks": blocks_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks: + block.re_organize_middle_weights(expand_ratio_stage) + + + +class DYNResNets_Cifar(ResNets_Cifar): + def __init__( + self, + n_classes=10, + bn_param=(0.1, 1e-5), + dropout_rate=0, + depth_list=0, + expand_ratio_list=0.25, + width_mult_list=1.0, + ): + + self.depth_list = val2list(depth_list) + self.expand_ratio_list = val2list(expand_ratio_list) + self.width_mult_list = val2list(width_mult_list) + # sort + self.depth_list.sort() + self.expand_ratio_list.sort() + self.width_mult_list.sort() + + input_channel = [ + make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + for width_mult in self.width_mult_list + ] + mid_input_channel = [ + make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) + for channel in input_channel + ] + + stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = [ + make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + for width_mult 
in self.width_mult_list + ] + + n_block_list = [ + base_depth + max(self.depth_list) for base_depth in ResNets_Cifar.BASE_DEPTH_LIST + ] + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + DynamicConvLayer( + val2list(3), + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ResidualBlock( + DynamicConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + DynamicConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(n_block_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = DynamicResNetBottleneckBlock( + input_channel, + width, + expand_ratio_list=self.expand_ratio_list, + kernel_size=3, + stride=stride, + act_func="relu", + downsample_mode="conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = DynamicLinearLayer( + input_channel, n_classes, dropout_rate=dropout_rate + ) + + super(DYNResNets_Cifar, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + # runtime_depth + self.input_stem_skipping = 0 + self.runtime_depth = [0] * len(n_block_list) + + @property + def ks_list(self): + return [3] + + @staticmethod + def name(): + return "DYNResNets_Cifar" + + def forward(self, x): + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + x = layer(x) + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + x = self.blocks[idx](x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + 
def module_str(self): + _str = "" + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + _str += layer.module_str + "\n" + # _str += "max_pooling(ks=3, stride=2)\n" + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": DYNResNets_Cifar.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + new_key = key + if new_key in model_dict: + pass + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." 
in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(DYNResNets_Cifar, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + d=max(self.depth_list), + e=max(self.expand_ratio_list), + w=len(self.width_mult_list) - 1, + ) + + def set_active_subnet(self, d=None, e=None, w=None, **kwargs): + depth = val2list(d, len(ResNets_Cifar.BASE_DEPTH_LIST) + 1) + expand_ratio = val2list(e, len(self.blocks)) + width_mult = val2list(w, len(ResNets_Cifar.BASE_DEPTH_LIST) + 2) + + for block, e in zip(self.blocks, expand_ratio): + if e is not None: + block.active_expand_ratio = e + + if width_mult[0] is not None: + self.input_stem[1].conv.active_out_channel = self.input_stem[ + 0 + ].active_out_channel = self.input_stem[0].out_channel_list[int(width_mult[0])] + if width_mult[1] is not None: + self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[ + int(width_mult[1]) + ] + + if depth[0] is not None: + self.input_stem_skipping = depth[0] != max(self.depth_list) + for stage_id, (block_idx, d, w) in enumerate( + zip(self.grouped_block_index, depth[1:], width_mult[2:]) + ): + if d is not None: + self.runtime_depth[stage_id] = max(self.depth_list) - d + if w is not None: + for idx in block_idx: + self.blocks[idx].active_out_channel = self.blocks[ + idx + ].out_channel_list[int(w)] + + def sample_active_subnet(self): + # sample expand ratio + expand_setting = [] + for block in self.blocks: + expand_setting.append(random.choice(block.expand_ratio_list)) + + # sample depth + depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])] + for stage_id in range(len(ResNets_Cifar.BASE_DEPTH_LIST)): + 
depth_setting.append(random.choice(self.depth_list)) + + # sample width_mult + width_mult_setting = [ + random.choice(list(range(len(self.input_stem[0].out_channel_list)))), + random.choice(list(range(len(self.input_stem[2].out_channel_list)))), + ] + for stage_id, block_idx in enumerate(self.grouped_block_index): + stage_first_block = self.blocks[block_idx[0]] + width_mult_setting.append( + random.choice(list(range(len(stage_first_block.out_channel_list)))) + ) + + arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting} + self.set_active_subnet(**arch_config) + return arch_config + + def get_active_subnet(self, preserve_weight=True): + input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)] + if self.input_stem_skipping <= 0: + input_stem.append( + ResidualBlock( + self.input_stem[1].conv.get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ), + IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + ) + ) + input_stem.append( + self.input_stem[2].get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + blocks.append( + self.blocks[idx].get_active_subnet(input_channel, preserve_weight) + ) + input_channel = self.blocks[idx].active_out_channel + classifier = self.classifier.get_active_subnet(input_channel, preserve_weight) + subnet = ResNets_Cifar(input_stem, blocks, classifier) + + subnet.set_bn_param(**self.get_bn_param()) + return subnet + + def get_active_net_config(self): + input_stem_config = [self.input_stem[0].get_active_subnet_config(3)] + if self.input_stem_skipping <= 0: + input_stem_config.append( + { + "name": ResidualBlock.__name__, + "conv": 
self.input_stem[1].conv.get_active_subnet_config( + self.input_stem[0].active_out_channel + ), + "shortcut": IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + } + ) + input_stem_config.append( + self.input_stem[2].get_active_subnet_config( + self.input_stem[0].active_out_channel + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks_config = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - int(depth_param)] + for idx in active_idx: + blocks_config.append( + self.blocks[idx].get_active_subnet_config(input_channel) + ) + input_channel = self.blocks[idx].active_out_channel + classifier_config = self.classifier.get_active_subnet_config(input_channel) + return { + "name": ResNets_Cifar.__name__, + "bn": self.get_bn_param(), + "input_stem": input_stem_config, + "blocks": blocks_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks: + block.re_organize_middle_weights(expand_ratio_stage) diff --git a/proard/classification/elastic_nn/training/__init__.py b/proard/classification/elastic_nn/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd8470bf42f2b96f7b950ede3d292260207b5a4 --- /dev/null +++ b/proard/classification/elastic_nn/training/__init__.py @@ -0,0 +1,6 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+from .progressive_shrinking import *
+# (duplicate `from .progressive_shrinking import *` removed; placeholder keeps hunk line count)
depth_list is None: + depth_list = dynamic_net.depth_list + if width_mult_list is not None: + if "width_mult_list" in dynamic_net.__dict__: + width_mult_list = list(range(len(dynamic_net.width_mult_list))) + else: + width_mult_list = [0] + + subnet_settings = [] + for d in depth_list: + for e in expand_ratio_list: + for k in ks_list: + for w in width_mult_list: + for img_size in image_size_list: + subnet_settings.append( + [ + { + "image_size": img_size, + "d": d, + "e": e, + "ks": k, + "w": w, + }, + "R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w), + ] + ) + if additional_setting is not None: + subnet_settings += additional_setting + + losses_of_subnets, top1_of_subnets, top5_of_subnets , robust1_of_subnets , robust5_of_subnets = [], [], [],[],[] + + valid_log = "" + for setting, name in subnet_settings: + run_manager.write_log( + "-" * 30 + " Validate %s " % name + "-" * 30, "train", should_print=False + ) + run_manager.run_config.data_provider.assign_active_img_size( + setting.pop("image_size") + ) + dynamic_net.set_active_subnet(**setting) + run_manager.write_log(dynamic_net.module_str, "train", should_print=False) + + run_manager.reset_running_statistics(dynamic_net) + loss, (top1, top5,robust1,robust5) = run_manager.validate( + epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net + ) + losses_of_subnets.append(loss) + top1_of_subnets.append(top1) + top5_of_subnets.append(top5) + robust1_of_subnets.append(robust1) + robust5_of_subnets.append(robust5) + valid_log += "%s (%.3f) (%.3f), " % (name, top1,robust1) + + return ( + list_mean(losses_of_subnets), + list_mean(top1_of_subnets), + list_mean(top5_of_subnets), + list_mean(robust1_of_subnets), + list_mean(robust5_of_subnets), + valid_log, + ) + + +def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0): + dynamic_net = run_manager.network + distributed = isinstance(run_manager, DistributedRunManager) + + # switch to train mode + dynamic_net.train() + if distributed: + 
run_manager.run_config.train_loader.sampler.set_epoch(epoch) + MyRandomResizedCrop.EPOCH = epoch + + nBatch = len(run_manager.run_config.train_loader) + + data_time = AverageMeter() + losses = DistributedMetric("train_loss") if distributed else AverageMeter() + metric_dict = run_manager.get_metric_dict() + + with tqdm( + total=nBatch, + desc="Train Epoch #{}".format(epoch + 1), + disable=distributed and not run_manager.is_root, + ) as t: + end = time.time() + subnet_str = "" + j=0 + for _ in range(args.dynamic_batch_size): + # set random seed before sampling + subnet_seed = int("%d%.3d%.3d" % (epoch * nBatch + j, _, 0)) + random.seed(subnet_seed) + subnet_settings = dynamic_net.sample_active_subnet() + subnet_str += ( + "%d: " % _ + + ",".join( + [ + "%s_%s" + % ( + key, + "%.1f" % subset_mean(val, 0) + if isinstance(val, list) + else val, + ) + for key, val in subnet_settings.items() + ] + ) + + " || " + ) + + for i, (images, labels) in enumerate(run_manager.run_config.train_loader): + MyRandomResizedCrop.BATCH = i + data_time.update(time.time() - end) + if epoch < warmup_epochs: + new_lr = run_manager.run_config.warmup_adjust_learning_rate( + run_manager.optimizer, + warmup_epochs * nBatch, + nBatch, + epoch, + i, + warmup_lr, + ) + else: + new_lr = run_manager.run_config.adjust_learning_rate( + run_manager.optimizer, epoch - warmup_epochs, i, nBatch + ) + + images, labels = images.cuda(), labels.cuda() + target = labels + + # soft target + if args.kd_ratio > 0: + args.teacher_model.eval() + with torch.no_grad(): + soft_logits = args.teacher_model(images).detach() + soft_label = F.softmax(soft_logits, dim=1) + + # clean gradients + dynamic_net.zero_grad() + + loss_of_subnets = [] + # compute output + + + output = dynamic_net(images) + + if args.kd_ratio == 0: + if run_manager.run_config.robust_mode: + loss = 
run_manager.train_criterion(dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train,run_manager.run_config.distance_train) + loss_type = run_manager.run_config.train_criterion_loss.__name__ + else: + loss = torch.nn.CrossEntropyLoss(output,labels) + loss_type = 'ce' + else: + if run_manager.run_config.robust_mode: + loss = run_manager.kd_criterion(args.teacher_model,dynamic_net,images,labels,run_manager.optimizer,run_manager.run_config.step_size_train,run_manager.run_config.epsilon_train,run_manager.run_config.num_steps_train,run_manager.run_config.beta_train) + loss_type = run_manager.run_config.kd_criterion_loss.__name__ + else: + if args.kd_type == "ce": + kd_loss = cross_entropy_loss_with_soft_target( + output, soft_label + ) + else: + kd_loss = F.mse_loss(output, soft_logits) + loss = args.kd_ratio * kd_loss + loss + loss_type = "%.1fkd+ce" % args.kd_ratio + # measure accuracy and record loss + loss_of_subnets.append(loss) + run_manager.update_metric(metric_dict, output,output, target) + + loss.backward() + run_manager.optimizer.step() + + losses.update(list_mean(loss_of_subnets), images.size(0)) + + t.set_postfix( + { + "loss": losses.avg.item(), + **run_manager.get_metric_vals(metric_dict, return_dict=True), + "R": images.size(2), + "lr": new_lr, + "loss_type": loss_type, + "seed": str(subnet_seed), + "str": subnet_str, + "data_time": data_time.avg, + } + ) + t.update(1) + end = time.time() + j+=1 + return losses.avg.item(), run_manager.get_metric_vals(metric_dict) + + +def train(run_manager, args, validate_func=None): + distributed = isinstance(run_manager, DistributedRunManager) + if validate_func is None: + validate_func = validate + + for epoch in range( + run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs + ): + train_loss, (train_top1, train_top5 , train_robust1 , train_robust5) = 
train_one_epoch( + run_manager, args, epoch, args.warmup_epochs, args.warmup_lr + ) + + if (epoch + 1) % args.validation_frequency == 0: + val_loss, val_acc, val_acc5, val_robust1, val_robust5, _val_log = validate_func( + run_manager, epoch=epoch, is_test=True + ) + # best_acc + is_best = val_acc > run_manager.best_acc + is_best_robust = val_robust1 > run_manager.best_robustness + run_manager.best_acc = max(run_manager.best_acc, val_acc) + run_manager.best_robustness = max(run_manager.best_robustness, val_robust1) + if not distributed or run_manager.is_root: + val_log = ( + "Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f}) , robust-1 = {4:.3f} ({5:.3f}) ".format( + epoch + 1 - args.warmup_epochs, + run_manager.run_config.n_epochs, + val_loss, + val_acc, + run_manager.best_acc, + val_robust1, + run_manager.best_robustness, + ) + ) + val_log += ", Train top-1 {top1:.3f}, Train robust-1 {robust1:.3f}, Train loss {loss:.3f}\t".format( + top1=train_top1, robust1 = train_robust1, loss=train_loss + ) + val_log += _val_log + run_manager.write_log(val_log, "valid", should_print=False) + + run_manager.save_model( + { + "epoch": epoch, + "best_acc": run_manager.best_acc, + "optimizer": run_manager.optimizer.state_dict(), + "state_dict": run_manager.network.state_dict(), + }, + is_best=is_best, + ) + + +def load_models(run_manager, dynamic_net, model_path=None): + # specify init path + init = torch.load(model_path, map_location="cpu")["state_dict"] + dynamic_net.load_state_dict(init) + run_manager.write_log("Loaded init from %s" % model_path, "valid") + + +def train_elastic_depth(train_func, run_manager, args, validate_func_dict): + dynamic_net = run_manager.net + if isinstance(dynamic_net, nn.DataParallel): + dynamic_net = dynamic_net.module + + depth_stage_list = dynamic_net.depth_list.copy() + depth_stage_list.sort(reverse=True) + n_stages = len(depth_stage_list) - 1 + current_stage = n_stages - 1 + + # load pretrained models + if run_manager.start_epoch == 0 and not 
args.resume: + validate_func_dict["depth_list"] = sorted(dynamic_net.depth_list) + + load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) + # validate after loading weights + run_manager.write_log( + "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" + % validate(run_manager, is_test=True, **validate_func_dict), + "valid", + ) + else: + assert args.resume + + run_manager.write_log( + "-" * 30 + + "Supporting Elastic Depth: %s -> %s" + % (depth_stage_list[: current_stage + 1], depth_stage_list[: current_stage + 2]) + + "-" * 30, + "valid", + ) + # add depth list constraints + if ( + len(set(dynamic_net.ks_list)) == 1 + and len(set(dynamic_net.expand_ratio_list)) == 1 + ): + validate_func_dict["depth_list"] = depth_stage_list + else: + validate_func_dict["depth_list"] = sorted( + {min(depth_stage_list), max(depth_stage_list)} + ) + + # train + train_func( + run_manager, + args, + lambda _run_manager, epoch, is_test: validate( + _run_manager, epoch, is_test, **validate_func_dict + ), + ) + + +def train_elastic_expand(train_func, run_manager, args, validate_func_dict): + dynamic_net = run_manager.net + if isinstance(dynamic_net, nn.DataParallel): + dynamic_net = dynamic_net.module + + expand_stage_list = dynamic_net.expand_ratio_list.copy() + expand_stage_list.sort(reverse=True) + n_stages = len(expand_stage_list) - 1 + current_stage = n_stages - 1 + + # load pretrained models + if run_manager.start_epoch == 0 and not args.resume: + validate_func_dict["expand_ratio_list"] = sorted(dynamic_net.expand_ratio_list) + + load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) + dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage) + run_manager.write_log( + "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" + % validate(run_manager, is_test=True, **validate_func_dict), + "valid", + ) + else: + assert args.resume + + run_manager.write_log( + "-" * 30 + + "Supporting Elastic Expand Ratio: %s -> %s" + % ( + expand_stage_list[: current_stage + 1], + 
expand_stage_list[: current_stage + 2], + ) + + "-" * 30, + "valid", + ) + if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1: + validate_func_dict["expand_ratio_list"] = expand_stage_list + else: + validate_func_dict["expand_ratio_list"] = sorted( + {min(expand_stage_list), max(expand_stage_list)} + ) + + # train + train_func( + run_manager, + args, + lambda _run_manager, epoch, is_test: validate( + _run_manager, epoch, is_test, **validate_func_dict + ), + ) + + +def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict): + dynamic_net = run_manager.net + if isinstance(dynamic_net, nn.DataParallel): + dynamic_net = dynamic_net.module + + width_stage_list = dynamic_net.width_mult_list.copy() + width_stage_list.sort(reverse=True) + n_stages = len(width_stage_list) - 1 + current_stage = n_stages - 1 + + if run_manager.start_epoch == 0 and not args.resume: + load_models(run_manager, dynamic_net, model_path=args.dyn_checkpoint_path) + if current_stage == 0: + dynamic_net.re_organize_middle_weights( + expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1 + ) + run_manager.write_log( + "reorganize_middle_weights (expand_ratio_stage=%d)" + % (len(dynamic_net.expand_ratio_list) - 1), + "valid", + ) + try: + dynamic_net.re_organize_outer_weights() + run_manager.write_log("reorganize_outer_weights", "valid") + except Exception: + pass + validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1}) + run_manager.write_log( + "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" + % validate(run_manager, is_test=True, **validate_func_dict), + "valid", + ) + else: + assert args.resume + + run_manager.write_log( + "-" * 30 + + "Supporting Elastic Width Mult: %s -> %s" + % (width_stage_list[: current_stage + 1], width_stage_list[: current_stage + 2]) + + "-" * 30, + "valid", + ) + validate_func_dict["width_mult_list"] = sorted({0, len(width_stage_list) - 1}) + + # train + train_func( + run_manager, + args, + lambda _run_manager, 
epoch, is_test: validate( + _run_manager, epoch, is_test, **validate_func_dict + ), + ) diff --git a/proard/classification/elastic_nn/utils.py b/proard/classification/elastic_nn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..755698fa65b02cbd1ad0fe31dc2d1c921479b764 --- /dev/null +++ b/proard/classification/elastic_nn/utils.py @@ -0,0 +1,83 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import copy +import torch.nn.functional as F +import torch.nn as nn +import torch +from attacks import create_attack +from attacks.utils import ctx_noparamgrad_and_eval +from proard.utils import AverageMeter, get_net_device, DistributedTensor +from proard.classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d + +__all__ = ["set_running_statistics"] + +def set_running_statistics(model, data_loader, distributed=False): + bn_mean = {} + bn_var = {} + + forward_model = copy.deepcopy(model) + for name, m in forward_model.named_modules(): + if isinstance(m, nn.BatchNorm2d): + if distributed: + bn_mean[name] = DistributedTensor(name + "#mean") + bn_var[name] = DistributedTensor(name + "#var") + else: + bn_mean[name] = AverageMeter() + bn_var[name] = AverageMeter() + + def new_forward(bn, mean_est, var_est): + def lambda_forward(x): + batch_mean = ( + x.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) # 1, C, 1, 1 + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = ( + batch_var.mean(0, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) + + batch_mean = torch.squeeze(batch_mean) + batch_var = torch.squeeze(batch_var) + + mean_est.update(batch_mean.data, x.size(0)) + var_est.update(batch_var.data, x.size(0)) + + # bn forward using calculated mean & var + _feature_dim = batch_mean.size(0) + return F.batch_norm( + x, + 
batch_mean, + batch_var, + bn.weight[:_feature_dim], + bn.bias[:_feature_dim], + False, + 0.0, + bn.eps, + ) + + return lambda_forward + + m.forward = new_forward(m, bn_mean[name], bn_var[name]) + + if len(bn_mean) == 0: + # skip if there is no batch normalization layers in the network + return + + with torch.no_grad(): + DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True + for images, labels in data_loader: + images = images.to(get_net_device(forward_model)) + forward_model(images) + DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False + + for name, m in model.named_modules(): + if name in bn_mean and bn_mean[name].count > 0: + feature_dim = bn_mean[name].avg.size(0) + assert isinstance(m, nn.BatchNorm2d) + m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) + m.running_var.data[:feature_dim].copy_(bn_var[name].avg) diff --git a/proard/classification/networks/__init__.py b/proard/classification/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1fac8eb04b68f8e997e620b64cd758f3046a707b --- /dev/null +++ b/proard/classification/networks/__init__.py @@ -0,0 +1,25 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +from .proxyless_nets import * +from .mobilenet_v3 import * +from .resnets import * +from .wide_resnet import WideResNet +from .resnet_trades import * + +def get_net_by_name(name): + if name == ProxylessNASNets.__name__: + return ProxylessNASNets + elif name == MobileNetV3.__name__: + return MobileNetV3 + elif name == ResNets.__name__: + return ResNets + if name == ProxylessNASNets_Cifar.__name__: + return ProxylessNASNets_Cifar + elif name == MobileNetV3_Cifar.__name__: + return MobileNetV3 + elif name == ResNets_Cifar.__name__: + return ResNets_Cifar + else: + raise ValueError("unrecognized type of network: %s" % name) diff --git a/proard/classification/networks/mobilenet_v3.py b/proard/classification/networks/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..d6777407264b3cf65e7c53246c482c5c8b862f3d --- /dev/null +++ b/proard/classification/networks/mobilenet_v3.py @@ -0,0 +1,559 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import copy +import torch.nn as nn + +from proard.utils.layers import ( + set_layer_from_config, + MBConvLayer, + ConvLayer, + IdentityLayer, + LinearLayer, + ResidualBlock, +) +from proard.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d + +__all__ = ["MobileNetV3", "MobileNetV3Large","MobileNetV3_Cifar", "MobileNetV3Large_Cifar"] + + +class MobileNetV3(MyNetwork): + def __init__( + self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ): + super(MobileNetV3, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.final_expand_layer = final_expand_layer + self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True) + self.feature_mix_layer = feature_mix_layer + self.classifier = classifier + + def forward(self, x): + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + x = self.final_expand_layer(x) + x = self.global_avg_pool(x) # global average pooling + x = self.feature_mix_layer(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.final_expand_layer.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": MobileNetV3.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "final_expand_layer": self.final_expand_layer.config, + "feature_mix_layer": self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + first_conv = set_layer_from_config(config["first_conv"]) + final_expand_layer = set_layer_from_config(config["final_expand_layer"]) + feature_mix_layer = 
set_layer_from_config(config["feature_mix_layer"]) + classifier = set_layer_from_config(config["classifier"]) + + blocks = [] + for block_config in config["blocks"]: + blocks.append(ResidualBlock.build_from_config(block_config)) + + net = MobileNetV3( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.conv, MBConvLayer) and isinstance( + m.shortcut, IdentityLayer + ): + m.conv.point_linear.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks[1:], 1): + if block.shortcut is None and len(block_index_list) > 0: + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + @staticmethod + def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate): + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=2, + use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # build mobile blocks + feature_dim = input_channel + blocks = [] + for stage_id, block_config_list in cfg.items(): + for ( + k, + mid_channel, + out_channel, + use_se, + act_func, + stride, + expand_ratio, + ) in block_config_list: + mb_conv = MBConvLayer( + feature_dim, + out_channel, + k, + stride, + expand_ratio, + mid_channel, + act_func, + use_se, + ) + if stride == 1 and out_channel == feature_dim: + shortcut = IdentityLayer(out_channel, out_channel) + else: + shortcut = None + blocks.append(ResidualBlock(mb_conv, shortcut)) + feature_dim = out_channel + # final expand layer + final_expand_layer = ConvLayer( + feature_dim, + feature_dim * 6, + kernel_size=1, + 
use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # feature mix layer + feature_mix_layer = ConvLayer( + feature_dim * 6, + last_channel, + kernel_size=1, + bias=False, + use_bn=False, + act_func="h_swish", + ) + # classifier + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + + @staticmethod + def adjust_cfg( + cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None + ): + for i, (stage_id, block_config_list) in enumerate(cfg.items()): + for block_config in block_config_list: + if ks is not None and stage_id != "0": + block_config[0] = ks + if expand_ratio is not None and stage_id != "0": + block_config[-1] = expand_ratio + block_config[1] = None + if stage_width_list is not None: + block_config[2] = stage_width_list[i] + if depth_param is not None and stage_id != "0": + new_block_config_list = [block_config_list[0]] + new_block_config_list += [ + copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1) + ] + cfg[stage_id] = new_block_config_list + return cfg + + def load_state_dict(self, state_dict, **kwargs): + current_state_dict = self.state_dict() + + for key in state_dict: + if key not in current_state_dict: + assert ".mobile_inverted_conv." 
in key + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + current_state_dict[new_key] = state_dict[key] + super(MobileNetV3, self).load_state_dict(current_state_dict) + + +class MobileNetV3Large(MobileNetV3): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0.2, + ks=None, + expand_ratio=None, + depth_param=None, + stage_width_list=None, + ): + input_channel = 16 + last_channel = 1280 + + input_channel = make_divisible( + input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + last_channel = ( + make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + if width_mult > 1.0 + else last_channel + ) + + cfg = { + # k, exp, c, se, nl, s, e, + "0": [ + [3, 16, 16, False, "relu", 1, 1], + ], + "1": [ + [3, 64, 24, False, "relu", 2, None], # 4 + [3, 72, 24, False, "relu", 1, None], # 3 + ], + "2": [ + [5, 72, 40, True, "relu", 2, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + ], + "3": [ + [3, 240, 80, False, "h_swish", 2, None], # 6 + [3, 200, 80, False, "h_swish", 1, None], # 2.5 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + ], + "4": [ + [3, 480, 112, True, "h_swish", 1, None], # 6 + [3, 672, 112, True, "h_swish", 1, None], # 6 + ], + "5": [ + [5, 672, 160, True, "h_swish", 2, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + ], + } + + cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list) + # width multiplier on mobile setting, change `exp: 1` and `c: 2` + for stage_id, block_config_list in cfg.items(): + for block_config in block_config_list: + if block_config[1] is not None: + block_config[1] = make_divisible( + block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + block_config[2] = make_divisible( + block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + ( + 
first_conv, + blocks, + final_expand_layer, + feature_mix_layer, + classifier, + ) = self.build_net_via_cfg( + cfg, input_channel, last_channel, n_classes, dropout_rate + ) + super(MobileNetV3Large, self).__init__( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + # set bn param + self.set_bn_param(*bn_param) + + + +class MobileNetV3_Cifar(MyNetwork): + def __init__( + self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ): + super(MobileNetV3_Cifar, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.final_expand_layer = final_expand_layer + self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True) + self.feature_mix_layer = feature_mix_layer + self.classifier = classifier + + def forward(self, x): + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + x = self.final_expand_layer(x) + x = self.global_avg_pool(x) # global average pooling + x = self.feature_mix_layer(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.final_expand_layer.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": MobileNetV3_Cifar.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "final_expand_layer": self.final_expand_layer.config, + "feature_mix_layer": self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + first_conv = set_layer_from_config(config["first_conv"]) + final_expand_layer = set_layer_from_config(config["final_expand_layer"]) + feature_mix_layer = 
set_layer_from_config(config["feature_mix_layer"]) + classifier = set_layer_from_config(config["classifier"]) + + blocks = [] + for block_config in config["blocks"]: + blocks.append(ResidualBlock.build_from_config(block_config)) + + net = MobileNetV3_Cifar( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.conv, MBConvLayer) and isinstance( + m.shortcut, IdentityLayer + ): + m.conv.point_linear.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks[1:], 1): + if block.shortcut is None and len(block_index_list) > 0: + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + @staticmethod + def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate): + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=1, + use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # build mobile blocks + feature_dim = input_channel + blocks = [] + for stage_id, block_config_list in cfg.items(): + for ( + k, + mid_channel, + out_channel, + use_se, + act_func, + stride, + expand_ratio, + ) in block_config_list: + mb_conv = MBConvLayer( + feature_dim, + out_channel, + k, + stride, + expand_ratio, + mid_channel, + act_func, + use_se, + ) + if stride == 1 and out_channel == feature_dim: + shortcut = IdentityLayer(out_channel, out_channel) + else: + shortcut = None + blocks.append(ResidualBlock(mb_conv, shortcut)) + feature_dim = out_channel + # final expand layer + final_expand_layer = ConvLayer( + feature_dim, + feature_dim * 6, + kernel_size=1, 
+ use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # feature mix layer + feature_mix_layer = ConvLayer( + feature_dim * 6, + last_channel, + kernel_size=1, + bias=False, + use_bn=False, + act_func="h_swish", + ) + # classifier + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + + @staticmethod + def adjust_cfg( + cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None + ): + for i, (stage_id, block_config_list) in enumerate(cfg.items()): + for block_config in block_config_list: + if ks is not None and stage_id != "0": + block_config[0] = ks + if expand_ratio is not None and stage_id != "0": + block_config[-1] = expand_ratio + block_config[1] = None + if stage_width_list is not None: + block_config[2] = stage_width_list[i] + if depth_param is not None and stage_id != "0": + new_block_config_list = [block_config_list[0]] + new_block_config_list += [ + copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1) + ] + cfg[stage_id] = new_block_config_list + return cfg + + def load_state_dict(self, state_dict, **kwargs): + current_state_dict = self.state_dict() + + for key in state_dict: + if key not in current_state_dict: + assert ".mobile_inverted_conv." 
in key + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + current_state_dict[new_key] = state_dict[key] + super(MobileNetV3_Cifar, self).load_state_dict(current_state_dict) + + +class MobileNetV3Large_Cifar(MobileNetV3_Cifar): + def __init__( + self, + n_classes=10, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0.2, + ks=None, + expand_ratio=None, + depth_param=None, + stage_width_list=None, + ): + input_channel = 16 + last_channel = 1280 + + input_channel = make_divisible( + input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + last_channel = ( + make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + if width_mult > 1.0 + else last_channel + ) + + cfg = { + # k, exp, c, se, nl, s, e, + "0": [ + [3, 16, 16, False, "relu", 1, 1], + ], + "1": [ + [3, 64, 24, False, "relu", 1, None], # 4 + [3, 72, 24, False, "relu", 1, None], # 3 + ], + "2": [ + [5, 72, 40, True, "relu", 2, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + ], + "3": [ + [3, 240, 80, False, "h_swish", 2, None], # 6 + [3, 200, 80, False, "h_swish", 1, None], # 2.5 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + ], + "4": [ + [3, 480, 112, True, "h_swish", 1, None], # 6 + [3, 672, 112, True, "h_swish", 1, None], # 6 + ], + "5": [ + [5, 672, 160, True, "h_swish", 2, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + ], + } + + cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list) + # width multiplier on mobile setting, change `exp: 1` and `c: 2` + for stage_id, block_config_list in cfg.items(): + for block_config in block_config_list: + if block_config[1] is not None: + block_config[1] = make_divisible( + block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + block_config[2] = make_divisible( + block_config[2] * width_mult, 
MyNetwork.CHANNEL_DIVISIBLE
        )

        (
            first_conv,
            blocks,
            final_expand_layer,
            feature_mix_layer,
            classifier,
        ) = self.build_net_via_cfg(
            cfg, input_channel, last_channel, n_classes, dropout_rate
        )
        super(MobileNetV3Large_Cifar, self).__init__(
            first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
        )
        # set bn param
        self.set_bn_param(*bn_param)
diff --git a/proard/classification/networks/proxyless_nets.py b/proard/classification/networks/proxyless_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c53d00c0deb712508f55c3baec5753312644d2
--- /dev/null
+++ b/proard/classification/networks/proxyless_nets.py
@@ -0,0 +1,490 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch.nn as nn

from proard.utils.layers import (
    set_layer_from_config,
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    LinearLayer,
    ResidualBlock,
)
from proard.utils import (
    download_url,
    make_divisible,
    val2list,
    MyNetwork,
    MyGlobalAvgPool2d,
)

__all__ = ["proxyless_base_cifar", "proxyless_base", "ProxylessNASNets", "MobileNetV2", "ProxylessNASNets_Cifar", "MobileNetV2_Cifar"]


def proxyless_base(
    net_config=None,
    n_classes=None,
    bn_param=None,
    dropout_rate=None,
    local_path="~/.torch/proxylessnas/",
):
    """Build a ProxylessNASNets from a JSON net config (local path or URL).

    `n_classes` / `dropout_rate`, when given, override the classifier entries
    of the loaded config; `bn_param` is forwarded to `set_bn_param`.
    """
    assert net_config is not None, "Please input a network config"
    if "http" in net_config:
        # remote config: download it next to the proxylessnas cache first
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, "r"))

    if n_classes is not None:
        net_config_json["classifier"]["out_features"] = n_classes
    if dropout_rate is not None:
        net_config_json["classifier"]["dropout_rate"] = dropout_rate

    net = ProxylessNASNets.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


class ProxylessNASNets(MyNetwork):
    """first_conv -> MBConv residual blocks -> optional 1x1 mix -> GAP -> classifier."""

    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer  # may be None
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        # human-readable one-line-per-layer summary
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        """JSON-serializable description consumed by `build_from_config`."""
        return {
            "name": ProxylessNASNets.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            # legacy configs without a "bn" entry fall back to these defaults
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        # zero the last BN gamma of each residual MBConv so the block starts
        # as identity (common warm-up trick)
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        """Group block indices by stage; a block with no shortcut starts a new stage."""
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        # accept legacy checkpoints that used ".mobile_inverted_conv." instead
        # of the current ".conv." prefix
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNets):
    """MobileNetV2 (ImageNet flavor, stride-2 stem) expressed as a ProxylessNASNets."""

    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        inverted_residual_setting = [
            # t (expand), c (channels), n (repeats), s (first stride)
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 2],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            # override repeats for the middle stages only
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        # one kernel size per block, excluding the single t=1 block
        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=2,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        # NOTE(review): a 1x1 conv shortcut where channels differ;
                        # the CIFAR variant below deliberately disables this.
                        shortcut = ConvLayer(input_channel, output_channel, kernel_size=1, stride=1, bias=False, act_func=None)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(*bn_param)


def proxyless_base_cifar(
    net_config=None,
    n_classes=None,
    bn_param=None,
    dropout_rate=None,
    local_path="~/.torch/proxylessnas/",
):
    """CIFAR counterpart of `proxyless_base`; builds a ProxylessNASNets_Cifar."""
    assert net_config is not None, "Please input a network config"
    if "http" in net_config:
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, "r"))

    if n_classes is not None:
        net_config_json["classifier"]["out_features"] = n_classes
    if dropout_rate is not None:
        net_config_json["classifier"]["dropout_rate"] = dropout_rate

    net = ProxylessNASNets_Cifar.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


class ProxylessNASNets_Cifar(MyNetwork):
    """CIFAR variant of ProxylessNASNets (same topology contract, 32x32 inputs)."""

    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets_Cifar, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer  # may be None
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + "\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.feature_mix_layer.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            "name": ProxylessNASNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "first_conv": self.first_conv.config,
            "blocks": [block.config for block in self.blocks],
            "feature_mix_layer": None
            if self.feature_mix_layer is None
            else self.feature_mix_layer.config,
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config["first_conv"])
        feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
        classifier = set_layer_from_config(config["classifier"])

        blocks = []
        for block_config in config["blocks"]:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets_Cifar(first_conv, blocks, feature_mix_layer, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        # zero the last BN gamma of residual MBConv blocks (identity warm-up)
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(
                    m.shortcut, IdentityLayer
                ):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        # remap legacy ".mobile_inverted_conv." keys to ".conv."
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert ".mobile_inverted_conv." in key
                new_key = key.replace(".mobile_inverted_conv.", ".conv.")
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets_Cifar, self).load_state_dict(current_state_dict)


class MobileNetV2_Cifar(ProxylessNASNets_Cifar):
    """MobileNetV2 adapted to CIFAR: stride-1 stem and stride-1 second stage."""

    def __init__(
        self,
        n_classes=10,
        width_mult=1.0,
        bn_param=(0.1, 1e-3),
        dropout_rate=0.2,
        ks=None,
        expand_ratio=None,
        depth_param=None,
        stage_width_list=None,
    ):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(
            input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE
        )
        last_channel = (
            make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            if width_mult > 1.0
            else last_channel
        )

        inverted_residual_setting = [
            # t, c, n, s  (second stage keeps stride 1 for 32x32 inputs)
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 1],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer (stride 1: CIFAR images are already small)
        first_conv = ConvLayer(
            3,
            input_channel,
            kernel_size=3,
            stride=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        # unlike MobileNetV2, no conv shortcut here (deliberate)
                        shortcut = None  # ConvLayer(input_channel,output_channel,kernel_size=1,stride=1,bias=False,act_func=None)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel,
            last_channel,
            kernel_size=1,
            stride=1,
            use_bn=True,
            act_func="relu6",
            ops_order="weight_bn_act",
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2_Cifar, self).__init__(
            first_conv, blocks, feature_mix_layer, classifier
        )

        # set bn param
        self.set_bn_param(*bn_param)
diff --git a/proard/classification/networks/resnet_trades.py b/proard/classification/networks/resnet_trades.py
new file mode 100644
index
0000000000000000000000000000000000000000..a5ceba42d4aedfcc1c1fd9e61cb1a48cda27b549 --- /dev/null +++ b/proard/classification/networks/resnet_trades.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += 
self.shortcut(x) + out = F.relu(out) + return out + +from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d +class ResNet(MyNetwork): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18_trades(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet34_trades(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50_trades(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101_trades(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152_trades(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net = ResNet18_trades() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) \ No newline at end of file diff --git a/proard/classification/networks/resnets.py b/proard/classification/networks/resnets.py new file mode 100644 index 0000000000000000000000000000000000000000..085da00528527350e6762f6cc09e4db4ecd8446e --- /dev/null +++ 
b/proard/classification/networks/resnets.py @@ -0,0 +1,490 @@ +import torch.nn as nn + +from proard.utils.layers import ( + set_layer_from_config, + ConvLayer, + IdentityLayer, + LinearLayer, +) +from proard.utils.layers import ResNetBottleneckBlock, ResidualBlock +from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d + +__all__ = ["ResNets", "ResNet50", "ResNet50D","ResNets_Cifar","ResNet50_Cifar", "ResNet50D_Cifar"] + + +class ResNets(MyNetwork): + BASE_DEPTH_LIST = [2, 2, 4, 2] + STAGE_WIDTH_LIST = [256, 512, 1024, 2048] + + def __init__(self, input_stem, blocks, classifier): + super(ResNets, self).__init__() + + self.input_stem = nn.ModuleList(input_stem) + self.max_pooling = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + self.blocks = nn.ModuleList(blocks) + self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False) + self.classifier = classifier + + def forward(self, x): + for layer in self.input_stem: + x = layer(x) + x = self.max_pooling(x) + for block in self.blocks: + x = block(x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = "" + for layer in self.input_stem: + _str += layer.module_str + "\n" + _str += "max_pooling(ks=3, stride=2)\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": ResNets.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + classifier = set_layer_from_config(config["classifier"]) + + input_stem = [] + for layer_config in config["input_stem"]: + input_stem.append(set_layer_from_config(layer_config)) + blocks = [] + for block_config in config["blocks"]: + 
blocks.append(set_layer_from_config(block_config)) + + net = ResNets(input_stem, blocks, classifier) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResNetBottleneckBlock) and isinstance( + m.downsample, IdentityLayer + ): + m.conv3.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks): + if ( + not isinstance(block.downsample, IdentityLayer) + and len(block_index_list) > 0 + ): + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + def load_state_dict(self, state_dict, **kwargs): + super(ResNets, self).load_state_dict(state_dict) + + +class ResNet50(ResNets): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + stage_width_list = ResNets.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible( + width * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNets.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer( + 3, + input_channel, + kernel_size=7, + stride=2, + use_bn=True, + act_func="relu", + ops_order="weight_bn_act", + ) + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + 
width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + +class ResNet50D(ResNets): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + mid_input_channel = make_divisible( + input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE + ) + stage_width_list = ResNets.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible( + width * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNets.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func="relu"), + ResidualBlock( + ConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + ConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="avgpool_conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # 
classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50D, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + + +class ResNets_Cifar(MyNetwork): + + BASE_DEPTH_LIST = [2, 2, 4, 2] + STAGE_WIDTH_LIST = [256, 512, 1024, 2048] + + def __init__(self, input_stem, blocks, classifier): + super(ResNets_Cifar, self).__init__() + + self.input_stem = nn.ModuleList(input_stem) + self.blocks = nn.ModuleList(blocks) + self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False) + self.classifier = classifier + + def forward(self, x): + for layer in self.input_stem: + x = layer(x) + for block in self.blocks: + x = block(x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = "" + for layer in self.input_stem: + _str += layer.module_str + "\n" + # _str += "max_pooling(ks=3, stride=2)\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": ResNets_Cifar.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + classifier = set_layer_from_config(config["classifier"]) + + input_stem = [] + for layer_config in config["input_stem"]: + input_stem.append(set_layer_from_config(layer_config)) + blocks = [] + for block_config in config["blocks"]: + blocks.append(set_layer_from_config(block_config)) + + net = ResNets(input_stem, blocks, classifier) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResNetBottleneckBlock) and isinstance( + 
m.downsample, IdentityLayer + ): + m.conv3.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks): + if ( + not isinstance(block.downsample, IdentityLayer) + and len(block_index_list) > 0 + ): + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + def load_state_dict(self, state_dict, **kwargs): + super(ResNets_Cifar, self).load_state_dict(state_dict) + + +class ResNet50_Cifar(ResNets_Cifar): + def __init__( + self, + n_classes=10, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible( + width * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNets_Cifar.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=1, + use_bn=True, + act_func="relu", + ops_order="weight_bn_act", + ) + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50_Cifar, 
self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + +class ResNet50D_Cifar(ResNets_Cifar): + def __init__( + self, + n_classes=10, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) + mid_input_channel = make_divisible( + input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE + ) + stage_width_list = ResNets.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible( + width * width_mult, MyNetwork.CHANNEL_DIVISIBLE + ) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNets.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer(3, mid_input_channel, 3, stride=1, use_bn=True, act_func="relu"), + ResidualBlock( + ConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + ConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="avgpool_conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50D_Cifar, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) +if __name__=="__main__": + import torch + resnet = 
ResNet50_Cifar() + x = torch.randn((1,3,32,32)) + resnet(x) + diff --git a/proard/classification/networks/wide_resnet.py b/proard/classification/networks/wide_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f60adfaf3b74508f0e3cd932484c3e5016b511f4 --- /dev/null +++ b/proard/classification/networks/wide_resnet.py @@ -0,0 +1,93 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d + +class BasicBlock(nn.Module): + def __init__(self, in_planes, out_planes, stride, dropRate=0.0): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_planes) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + self.equalInOut = (in_planes == out_planes) + self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) or None + + def forward(self, x): + if not self.equalInOut: + x = self.relu1(self.bn1(x)) + else: + out = self.relu1(self.bn1(x)) + out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, training=self.training) + out = self.conv2(out) + return torch.add(x if self.equalInOut else self.convShortcut(x), out) + + +class NetworkBlock(nn.Module): + def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): + super(NetworkBlock, self).__init__() + self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) + + def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): + layers = [] + for i in range(int(nb_layers)): + 
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) + return nn.Sequential(*layers) + + def forward(self, x): + return self.layer(x) + + +class WideResNet(MyNetwork): + def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): + super(WideResNet, self).__init__() + nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] + assert ((depth - 4) % 6 == 0) + n = (depth - 4) / 6 + block = BasicBlock + # 1st conv before any network block + self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, + padding=1, bias=False) + # 1st block + self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) + # 1st sub-block + self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) + # 2nd block + self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) + # 3rd block + self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) + # global average pooling and classifier + self.bn1 = nn.BatchNorm2d(nChannels[3]) + self.relu = nn.ReLU(inplace=True) + self.fc = nn.Linear(nChannels[3], num_classes) + self.nChannels = nChannels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def forward(self, x): + out = self.conv1(x) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.relu(self.bn1(out)) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.nChannels) + return self.fc(out) + \ No newline at end of file diff --git a/proard/classification/run_manager/__init__.py b/proard/classification/run_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57a83b548c41f5400d65866f2313b2295131cefc --- /dev/null +++ b/proard/classification/run_manager/__init__.py @@ -0,0 +1,7 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +from .run_config import * +from .run_manager import * +from .distributed_run_manager import * diff --git a/proard/classification/run_manager/distributed_run_manager.py b/proard/classification/run_manager/distributed_run_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcf5619798d3cf0ba3c190f7412edf157904422 --- /dev/null +++ b/proard/classification/run_manager/distributed_run_manager.py @@ -0,0 +1,505 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from attacks import create_attack
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from attacks.utils import ctx_noparamgrad_and_eval
from proard.utils import (
    cross_entropy_with_label_smoothing,
    cross_entropy_loss_with_soft_target,
    write_log,
    init_models,
)
from proard.utils import (
    DistributedMetric,
    list_mean,
    get_net_info,
    accuracy,
    AverageMeter,
    mix_labels,
    mix_images,
)
from proard.utils import MyRandomResizedCrop

__all__ = ["DistributedRunManager"]


class DistributedRunManager:
    """Horovod-based distributed training/validation manager.

    Mirrors the single-node ``RunManager`` but synchronizes metrics with
    ``DistributedMetric`` and wraps the optimizer in
    ``hvd.DistributedOptimizer``.  Only the root worker (``is_root=True``)
    initializes weights, writes logs, and saves checkpoints.
    """

    def __init__(
        self,
        path,
        net,
        run_config,
        hvd_compression,
        backward_steps=1,
        is_root=False,
        init=True,
    ):
        # horovod is imported lazily so this module can be imported on
        # machines where horovod is not installed.
        import horovod.torch as hvd

        self.path = path
        self.net = net
        self.run_config = run_config
        self.is_root = is_root

        self.best_acc = 0.0          # best clean top-1 seen so far
        self.best_robustness = 0.0   # best adversarial top-1 seen so far
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        self.net.cuda()
        cudnn.benchmark = True
        if init and self.is_root:
            init_models(self.net, self.run_config.model_init)
        if self.is_root:
            # dump network info (FLOPs/params, architecture string, transforms)
            net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
            with open("%s/net_info.txt" % self.path, "w") as fout:
                fout.write(json.dumps(net_info, indent=4) + "\n")
                try:
                    fout.write(self.net.module_str + "\n")
                except Exception:
                    fout.write("%s do not support `module_str`" % type(self.net))
                fout.write(
                    "%s\n" % self.run_config.data_provider.train.dataset.transform
                )
                fout.write(
                    "%s\n" % self.run_config.data_provider.test.dataset.transform
                )
                fout.write("%s\n" % self.net)

        # criteria are resolved by the run config (trades/mart/..., ce, ard/rslad/...)
        self.train_criterion = self.run_config.train_criterion_loss
        self.test_criterion = self.run_config.test_criterion_loss
        self.kd_criterion = self.run_config.kd_criterion_loss

        # optimizer: optionally split parameters into decay / no-decay groups
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split("#")
            net_params = [
                self.net.get_parameters(
                    keys, mode="exclude"
                ),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode="include"
                ),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = [
                    param for param in self.network.parameters() if param.requires_grad
                ]
        self.optimizer = self.run_config.build_optimizer(net_params)
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer,
            named_parameters=self.net.named_parameters(),
            compression=hvd_compression,
            backward_passes_per_step=backward_steps,
        )

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get("_save_path", None) is None:
            save_path = os.path.join(self.path, "checkpoint")
            os.makedirs(save_path, exist_ok=True)
            self.__dict__["_save_path"] = save_path
        return self.__dict__["_save_path"]

    @property
    def logs_path(self):
        if self.__dict__.get("_logs_path", None) is None:
            logs_path = os.path.join(self.path, "logs")
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__["_logs_path"] = logs_path
        return self.__dict__["_logs_path"]

    @property
    def network(self):
        return self.net

    @network.setter
    def network(self, new_val):
        self.net = new_val

    def write_log(self, log_str, prefix="valid", should_print=True, mode="a"):
        """Append `log_str` to the log file (root worker only)."""
        if self.is_root:
            write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save & load model & save_config & broadcast """

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """Dump run config and net config to the model folder (root only)."""
        if self.is_root:
            run_save_path = os.path.join(self.path, "run.config")
            if not os.path.isfile(run_save_path):
                run_config = self.run_config.config
                if extra_run_config is not None:
                    run_config.update(extra_run_config)
                json.dump(run_config, open(run_save_path, "w"), indent=4)
                print("Run configs dump to %s" % run_save_path)

            try:
                net_save_path = os.path.join(self.path, "net.config")
                net_config = self.net.config
                if extra_net_config is not None:
                    net_config.update(extra_net_config)
                json.dump(net_config, open(net_save_path, "w"), indent=4)
                print("Network configs dump to %s" % net_save_path)
            except Exception:
                print("%s do not support net config" % type(self.net))

    def save_model(
        self, checkpoint=None, is_best=False, model_name=None, is_best_robust=False
    ):
        """Save `checkpoint` (root only).

        `is_best` additionally snapshots the best clean-accuracy model;
        `is_best_robust` (new, defaulted for backward compatibility)
        snapshots the best adversarial-accuracy model.
        """
        if self.is_root:
            if checkpoint is None:
                checkpoint = {"state_dict": self.net.state_dict()}

            if model_name is None:
                model_name = "checkpoint.pth.tar"

            latest_fname = os.path.join(self.save_path, "latest.txt")
            model_path = os.path.join(self.save_path, model_name)
            with open(latest_fname, "w") as _fout:
                _fout.write(model_path + "\n")
            torch.save(checkpoint, model_path)

            if is_best:
                best_path = os.path.join(self.save_path, "model_best.pth.tar")
                torch.save({"state_dict": checkpoint["state_dict"]}, best_path)
            if is_best_robust:
                # BUG FIX: previously the robust-best model was never saved.
                best_path = os.path.join(self.save_path, "model_best_robust.pth.tar")
                torch.save({"state_dict": checkpoint["state_dict"]}, best_path)

    def load_model(self, model_fname=None):
        """Load the latest (or given) checkpoint and restore training state."""
        if self.is_root:
            latest_fname = os.path.join(self.save_path, "latest.txt")
            if model_fname is None and os.path.exists(latest_fname):
                with open(latest_fname, "r") as fin:
                    model_fname = fin.readline()
                    if model_fname[-1] == "\n":
                        model_fname = model_fname[:-1]
            # noinspection PyBroadException
            try:
                if model_fname is None or not os.path.exists(model_fname):
                    model_fname = "%s/checkpoint.pth.tar" % self.save_path
                    with open(latest_fname, "w") as fout:
                        fout.write(model_fname + "\n")
                print("=> loading checkpoint '{}'".format(model_fname))
                checkpoint = torch.load(model_fname, map_location="cpu")
            except Exception:
                self.write_log(
                    "fail to load checkpoint from %s" % self.save_path, "valid"
                )
                return

            self.net.load_state_dict(checkpoint["state_dict"])
            if "epoch" in checkpoint:
                self.start_epoch = checkpoint["epoch"] + 1
            if "best_acc" in checkpoint:
                self.best_acc = checkpoint["best_acc"]
            if "best_robustness" in checkpoint:
                # BUG FIX: best_robustness was saved nowhere and never restored.
                self.best_robustness = checkpoint["best_robustness"]
            if "optimizer" in checkpoint:
                self.optimizer.load_state_dict(checkpoint["optimizer"])

            self.write_log("=> loaded checkpoint '{}'".format(model_fname), "valid")

    # noinspection PyArgumentList
    def broadcast(self):
        """Broadcast training state from the root worker to all workers."""
        import horovod.torch as hvd

        self.start_epoch = hvd.broadcast(
            torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name="start_epoch"
        ).item()
        self.best_acc = hvd.broadcast(
            torch.Tensor(1).fill_(self.best_acc)[0], 0, name="best_acc"
        ).item()
        # keep the robust-best tracker consistent across workers as well
        self.best_robustness = hvd.broadcast(
            torch.Tensor(1).fill_(self.best_robustness)[0], 0, name="best_robustness"
        ).item()
        hvd.broadcast_parameters(self.net.state_dict(), 0)
        hvd.broadcast_optimizer_state(self.optimizer, 0)

    """ metric related """

    def get_metric_dict(self):
        return {
            "top1": DistributedMetric("top1"),
            "top5": DistributedMetric("top5"),
            "robust1": DistributedMetric("robust1"),
            "robust5": DistributedMetric("robust5"),
        }

    def update_metric(self, metric_dict, output, output_adv, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        robust1, robust5 = accuracy(output_adv, labels, topk=(1, 5))
        metric_dict["top1"].update(acc1[0], output.size(0))
        metric_dict["top5"].update(acc5[0], output.size(0))
        metric_dict["robust1"].update(robust1[0], output.size(0))
        metric_dict["robust5"].update(robust5[0], output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {key: metric_dict[key].avg.item() for key in metric_dict}
        else:
            return [metric_dict[key].avg.item() for key in metric_dict]

    def get_metric_names(self):
        return "top1", "top5", "robust1", "robust5"

    """ train & validate """

    def validate(
        self,
        epoch=0,
        is_test=False,
        run_str="",
        net=None,
        data_loader=None,
        no_logs=False,
    ):
        """Evaluate `net` on valid/test data; returns (loss, metric values).

        In robust mode each batch is also attacked (PGD-style via
        ``create_attack``) and the loss is computed on the adversarial output.
        """
        if net is None:
            net = self.net
        if data_loader is None:
            if is_test:
                data_loader = self.run_config.test_loader
            else:
                data_loader = self.run_config.valid_loader

        net.eval()
        if self.run_config.robust_mode:
            eval_attack = create_attack(
                net,
                self.test_criterion.cuda(),
                self.run_config.attack_type,
                self.run_config.epsilon_test,
                self.run_config.num_steps_test,
                self.run_config.step_size_test,
            )
        losses = DistributedMetric("val_loss")
        metric_dict = self.get_metric_dict()

        with tqdm(
            total=len(data_loader),
            desc="Validate Epoch #{} {}".format(epoch + 1, run_str),
            disable=no_logs or not self.is_root,
        ) as t:
            for i, (images, labels) in enumerate(data_loader):
                images, labels = images.cuda(), labels.cuda()
                # compute clean output
                output = net(images)
                if self.run_config.robust_mode:
                    # freeze params/BN while crafting adversarial examples
                    with ctx_noparamgrad_and_eval(net):
                        images_adv, _ = eval_attack.perturb(images, labels)
                    output_adv = net(images_adv)
                    loss = self.test_criterion(output_adv, labels)
                else:
                    output_adv = output
                    loss = self.test_criterion(output, labels)

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, output_adv, labels)
                t.set_postfix(
                    {
                        "loss": losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                    }
                )
                t.update(1)
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        """Validate at every configured image size (elastic resolution)."""
        if net is None:
            net = self.net
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            robust1_list, robust5_list = [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                # NOTE(review): BN running stats are recalibrated on clean
                # sub-train data; effect on robust accuracy is unverified.
                self.reset_running_statistics(net=net)
                loss, (top1, top5, robust1, robust5) = self.validate(
                    epoch, is_test, net=net
                )
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
                robust1_list.append(robust1)
                robust5_list.append(robust5)

            return (
                img_size_list,
                loss_list,
                top1_list,
                top5_list,
                robust1_list,
                robust5_list,
            )
        else:
            self.reset_running_statistics(net=net)
            loss, (top1, top5, robust1, robust5) = self.validate(
                epoch, is_test, net=net
            )
            return (
                [self.run_config.data_provider.active_img_size],
                [loss],
                [top1],
                [top5],
                [robust1],
                [robust5],
            )

    def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
        """Train for one epoch; returns (avg loss, metric values)."""
        self.net.train()
        self.run_config.train_loader.sampler.set_epoch(
            epoch
        )  # required by distributed sampler
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = DistributedMetric("train_loss")
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(
            total=nBatch,
            desc="Train Epoch #{}".format(epoch + 1),
            disable=not self.is_root,
        ) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer,
                        warmup_epochs * nBatch,
                        nBatch,
                        epoch,
                        i,
                        warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(
                        self.optimizer, epoch - warmup_epochs, i, nBatch
                    )

                images, labels = images.cuda(), labels.cuda()
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # seed by (batch, epoch) so every worker draws the same lam
                    random.seed(int("%d%.3d" % (i, epoch)))
                    lam = random.betavariate(
                        self.run_config.mixup_alpha, self.run_config.mixup_alpha
                    )
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels,
                        lam,
                        self.run_config.data_provider.n_classes,
                        self.run_config.label_smoothing,
                    )

                # soft target from the teacher (knowledge distillation)
                if args.teacher_model is not None:
                    # NOTE(review): teacher is put in train mode, so its BN
                    # running stats drift even under no_grad — confirm intent
                    # (eval mode is the usual choice for a frozen teacher).
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)
                if args.teacher_model is None:
                    if self.run_config.robust_mode:
                        loss = self.train_criterion(
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                            self.run_config.distance_train,
                        )
                        loss_type = self.train_criterion.__name__
                    else:
                        # BUG FIX: was `torch.nn.CrossEntropyLoss(output, labels)`,
                        # which builds a module instead of computing the loss.
                        loss = nn.CrossEntropyLoss()(output, labels)
                        loss_type = "ce"
                else:
                    if self.run_config.robust_mode:
                        loss = self.kd_criterion(
                            args.teacher_model,
                            self.net,
                            images,
                            labels,
                            self.optimizer,
                            self.run_config.step_size_train,
                            self.run_config.epsilon_train,
                            self.run_config.num_steps_train,
                            self.run_config.beta_train,
                        )
                        # BUG FIX: was `self.kd_criterion_loss.__name__`
                        # (no such attribute on the run manager).
                        loss_type = self.kd_criterion.__name__
                    else:
                        if args.kd_type == "ce":
                            kd_loss = cross_entropy_loss_with_soft_target(
                                output, soft_label
                            )
                        else:
                            kd_loss = F.mse_loss(output, soft_logits)
                        # BUG FIX: `loss` was read before assignment here;
                        # combine the KD term with the plain CE loss.
                        ce_loss = nn.CrossEntropyLoss()(output, labels)
                        loss = args.kd_ratio * kd_loss + ce_loss
                        loss_type = "%.1fkd+ce" % args.kd_ratio

                # update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, output, target)

                t.set_postfix(
                    {
                        "loss": losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        "img_size": images.size(2),
                        "lr": new_lr,
                        "loss_type": loss_type,
                        "data_time": data_time.avg,
                    }
                )
                t.update(1)
                end = time.time()
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epochs=5, warmup_lr=0):
        """Full training loop with per-epoch validation and checkpointing."""
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
            train_loss, (
                train_top1,
                train_top5,
                train_robust1,
                train_robust5,
            ) = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
            (
                img_size,
                val_loss,
                val_top1,
                val_top5,
                val_robust1,
                val_robust5,
            ) = self.validate_all_resolution(epoch, is_test=False)

            is_best = list_mean(val_top1) > self.best_acc
            is_best_robust = list_mean(val_robust1) > self.best_robustness
            self.best_robustness = max(self.best_robustness, list_mean(val_robust1))
            self.best_acc = max(self.best_acc, list_mean(val_top1))
            if self.is_root:
                val_log = (
                    "[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t {8} robust {10:.3f} ({4:.3f})\t{9} robust {11:.3f} "
                    "Train {6} {top1:.3f}\tloss {train_loss:.3f}\t robust1 {8} {robust1:.3f}\t".format(
                        epoch + 1 - warmup_epochs,
                        self.run_config.n_epochs,
                        list_mean(val_loss),
                        list_mean(val_top1),
                        self.best_acc,
                        list_mean(val_top5),
                        *self.get_metric_names(),
                        list_mean(val_robust1),
                        list_mean(val_robust5),
                        top1=train_top1,
                        train_loss=train_loss,
                        robust1=train_robust1,
                    )
                )
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += "(%d, %.3f), " % (i_s, v_a)
                self.write_log(val_log, prefix="valid", should_print=False)

                self.save_model(
                    {
                        "epoch": epoch,
                        "best_acc": self.best_acc,
                        # BUG FIX: persist the robust tracker too, so it can
                        # be restored by load_model after resume.
                        "best_robustness": self.best_robustness,
                        "optimizer": self.optimizer.state_dict(),
                        "state_dict": self.net.state_dict(),
                    },
                    is_best=is_best,
                    is_best_robust=is_best_robust,
                )

    def reset_running_statistics(
        self, net=None, subset_size=4000, subset_batch_size=200, data_loader=None
    ):
        """Recalibrate BN running stats on a small clean training subset."""
        from proard.classification.elastic_nn.utils import set_running_statistics

        if net is None:
            net = self.net
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(
                subset_size, subset_batch_size
            )

        set_running_statistics(net, data_loader)
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from proard.utils import calc_learning_rate, build_optimizer
from proard.classification.data_providers import ImagenetDataProvider
from proard.classification.data_providers import Cifar10DataProvider
from proard.classification.data_providers import Cifar100DataProvider
from robust_loss.trades import trades_loss
from robust_loss.adaad import adaad_loss
from robust_loss.ard import ard_loss
from robust_loss.hat import hat_loss
from robust_loss.mart import mart_loss
from robust_loss.sat import sat_loss
from robust_loss.rslad import rslad_loss
import torch

__all__ = ["RunConfig", "ClassificationRunConfig", "DistributedClassificationRunConfig"]


class RunConfig:
    """Base container for a training run: schedule, dataset, optimizer.

    Subclasses must implement the `data_provider` property.
    """

    def __init__(
        self,
        n_epochs,
        init_lr,
        lr_schedule_type,
        lr_schedule_param,
        dataset,
        train_batch_size,
        test_batch_size,
        valid_size,
        opt_type,
        opt_param,
        weight_decay,
        label_smoothing,
        no_decay_keys,
        mixup_alpha,
        model_init,
        validation_frequency,
        print_frequency,
    ):
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.lr_schedule_type = lr_schedule_type
        self.lr_schedule_param = lr_schedule_param

        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.valid_size = valid_size

        self.opt_type = opt_type
        self.opt_param = opt_param
        self.weight_decay = weight_decay
        self.label_smoothing = label_smoothing
        self.no_decay_keys = no_decay_keys

        self.mixup_alpha = mixup_alpha

        self.model_init = model_init
        self.validation_frequency = validation_frequency
        self.print_frequency = print_frequency

    @property
    def config(self):
        """Public (non-underscore) attributes as a plain dict."""
        return {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_")
        }

    def copy(self):
        # BUG FIX: was hard-coded `RunConfig(**self.config)`, which raised
        # TypeError for subclass instances (their config carries extra keys)
        # and always downcast the copy to the base class.
        return self.__class__(**self.config)

    """ learning rate """

    def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """Adjust learning rate of `optimizer`; returns the new learning rate."""
        new_lr = calc_learning_rate(
            epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return new_lr

    def warmup_adjust_learning_rate(
        self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0
    ):
        """Linear warmup from `warmup_lr` to `init_lr` over `T_total` steps."""
        T_cur = epoch * nBatch + batch + 1
        new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return new_lr

    """ data provider """

    @property
    def data_provider(self):
        raise NotImplementedError

    @property
    def train_loader(self):
        return self.data_provider.train

    @property
    def valid_loader(self):
        return self.data_provider.valid

    @property
    def test_loader(self):
        return self.data_provider.test

    def random_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        return self.data_provider.build_sub_train_loader(
            n_images, batch_size, num_worker, num_replicas, rank
        )

    """ optimizer """

    def build_optimizer(self, net_params):
        return build_optimizer(
            net_params,
            self.opt_type,
            self.opt_param,
            self.init_lr,
            self.weight_decay,
            self.no_decay_keys,
        )


class ClassificationRunConfig(RunConfig):
    """Run config for (optionally adversarially robust) classification."""

    def __init__(
        self,
        n_epochs=150,
        init_lr=0.05,
        lr_schedule_type="cosine",
        lr_schedule_param=None,
        dataset="imagenet",  # 'cifar10' or 'cifar100'
        train_batch_size=256,
        test_batch_size=500,
        valid_size=None,
        opt_type="sgd",
        opt_param=None,
        weight_decay=4e-5,
        label_smoothing=0.1,
        no_decay_keys=None,
        mixup_alpha=None,
        model_init="he_fout",
        validation_frequency=1,
        print_frequency=10,
        n_worker=32,
        resize_scale=0.08,
        distort_color="tf",
        image_size=224,  # 32
        robust_mode=False,
        epsilon_train=0.031,
        num_steps_train=10,
        step_size_train=0.0078,
        clip_min_train=0,
        clip_max_train=1,
        const_init_train=False,
        beta_train=6.0,
        distance_train="l_inf",
        epsilon_test=0.031,
        num_steps_test=20,
        step_size_test=0.0078,
        clip_min_test=0,
        clip_max_test=1,
        const_init_test=False,
        beta_test=6.0,
        distance_test="l_inf",
        train_criterion="trades",
        test_criterion="ce",
        kd_criterion="rslad",
        attack_type="linf-pgd",
        **kwargs
    ):
        super(ClassificationRunConfig, self).__init__(
            n_epochs,
            init_lr,
            lr_schedule_type,
            lr_schedule_param,
            dataset,
            train_batch_size,
            test_batch_size,
            valid_size,
            opt_type,
            opt_param,
            weight_decay,
            label_smoothing,
            no_decay_keys,
            mixup_alpha,
            model_init,
            validation_frequency,
            print_frequency,
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        # adversarial-training hyper-parameters (inner maximization)
        self.epsilon_train = epsilon_train
        self.num_steps_train = num_steps_train
        self.step_size_train = step_size_train
        self.clip_min_train = clip_min_train
        self.clip_max_train = clip_max_train
        self.const_init_train = const_init_train
        self.beta_train = beta_train
        self.distance_train = distance_train
        # adversarial-evaluation hyper-parameters
        self.epsilon_test = epsilon_test
        self.num_steps_test = num_steps_test
        self.step_size_test = step_size_test
        self.clip_min_test = clip_min_test
        self.clip_max_test = clip_max_test
        self.const_init_test = const_init_test
        self.beta_test = beta_test
        self.distance_test = distance_test
        self.train_criterion = train_criterion
        self.test_criterion = test_criterion
        self.kd_criterion = kd_criterion
        self.attack_type = attack_type
        self.robust_mode = robust_mode

    @property
    def data_provider(self):
        """Lazily build (and cache) the data provider for `self.dataset`."""
        if self.__dict__.get("_data_provider", None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            elif self.dataset == Cifar10DataProvider.name():
                DataProviderClass = Cifar10DataProvider
            elif self.dataset == Cifar100DataProvider.name():
                DataProviderClass = Cifar100DataProvider
            else:
                raise NotImplementedError
            self.__dict__["_data_provider"] = DataProviderClass(
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
        return self.__dict__["_data_provider"]

    @property
    def train_criterion_loss(self):
        """Map `train_criterion` name to its robust-loss callable."""
        if self.train_criterion == "trades":
            return trades_loss
        elif self.train_criterion == "mart":
            return mart_loss
        elif self.train_criterion == "sat":
            return sat_loss
        elif self.train_criterion == "hat":
            return hat_loss
        # BUG FIX: unknown names used to silently return None.
        raise NotImplementedError(
            "unsupported train_criterion: %s" % self.train_criterion
        )

    @property
    def test_criterion_loss(self):
        if self.test_criterion == "ce":
            return torch.nn.CrossEntropyLoss()
        raise NotImplementedError(
            "unsupported test_criterion: %s" % self.test_criterion
        )

    @property
    def kd_criterion_loss(self):
        """Map `kd_criterion` name to its robust-distillation callable."""
        if self.kd_criterion == "ard":
            return ard_loss
        elif self.kd_criterion == "adaad":
            return adaad_loss
        elif self.kd_criterion == "rslad":
            return rslad_loss
        raise NotImplementedError(
            "unsupported kd_criterion: %s" % self.kd_criterion
        )


class DistributedClassificationRunConfig(ClassificationRunConfig):
    """Distributed variant: single train/test robust params, sharded loaders.

    Requires `num_replicas` and `rank` in kwargs.
    """

    def __init__(
        self,
        n_epochs=150,
        init_lr=0.05,
        lr_schedule_type="cosine",
        lr_schedule_param=None,
        dataset="imagenet",
        train_batch_size=64,
        test_batch_size=64,
        valid_size=None,
        opt_type="sgd",
        opt_param=None,
        weight_decay=4e-5,
        label_smoothing=0.1,
        no_decay_keys=None,
        mixup_alpha=None,
        model_init="he_fout",
        validation_frequency=1,
        print_frequency=10,
        n_worker=8,
        resize_scale=0.08,
        distort_color="tf",
        image_size=224,
        robust_mode=False,
        epsilon=0.031,
        num_steps=10,
        step_size=0.0078,
        clip_min=0,
        clip_max=1,
        const_init=False,
        beta=6.0,
        distance="l_inf",
        train_criterion="trades",
        test_criterion="ce",
        kd_criterion="rslad",
        attack_type="linf-pgd",
        **kwargs
    ):
        # Keyword arguments replace the original fragile 30+-positional
        # call; the mapping is identical (note: twice the attack steps at
        # test time, matching the original `num_steps * 2`).
        super(DistributedClassificationRunConfig, self).__init__(
            n_epochs=n_epochs,
            init_lr=init_lr,
            lr_schedule_type=lr_schedule_type,
            lr_schedule_param=lr_schedule_param,
            dataset=dataset,
            train_batch_size=train_batch_size,
            test_batch_size=test_batch_size,
            valid_size=valid_size,
            opt_type=opt_type,
            opt_param=opt_param,
            weight_decay=weight_decay,
            label_smoothing=label_smoothing,
            no_decay_keys=no_decay_keys,
            mixup_alpha=mixup_alpha,
            model_init=model_init,
            validation_frequency=validation_frequency,
            print_frequency=print_frequency,
            n_worker=n_worker,
            resize_scale=resize_scale,
            distort_color=distort_color,
            image_size=image_size,
            robust_mode=robust_mode,
            epsilon_train=epsilon,
            num_steps_train=num_steps,
            step_size_train=step_size,
            clip_min_train=clip_min,
            clip_max_train=clip_max,
            const_init_train=const_init,
            beta_train=beta,
            distance_train=distance,
            epsilon_test=epsilon,
            num_steps_test=num_steps * 2,
            step_size_test=step_size,
            clip_min_test=clip_min,
            clip_max_test=clip_max,
            const_init_test=const_init,
            beta_test=beta,
            distance_test=distance,
            train_criterion=train_criterion,
            test_criterion=test_criterion,
            kd_criterion=kd_criterion,
            attack_type=attack_type,
            **kwargs
        )

        # both are mandatory for distributed data sharding
        self._num_replicas = kwargs["num_replicas"]
        self._rank = kwargs["rank"]

    @property
    def data_provider(self):
        """Lazily build (and cache) a sharded data provider."""
        if self.__dict__.get("_data_provider", None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            elif self.dataset == Cifar10DataProvider.name():
                DataProviderClass = Cifar10DataProvider
            elif self.dataset == Cifar100DataProvider.name():
                DataProviderClass = Cifar100DataProvider
            else:
                raise NotImplementedError
            if self.dataset == "imagenet":
                self.__dict__["_data_provider"] = DataProviderClass(
                    train_batch_size=self.train_batch_size,
                    test_batch_size=self.test_batch_size,
                    valid_size=self.valid_size,
                    n_worker=self.n_worker,
                    resize_scale=self.resize_scale,
                    distort_color=self.distort_color,
                    image_size=self.image_size,
                    num_replicas=self._num_replicas,
                    rank=self._rank,
                )
            else:
                # CIFAR providers ignore resize/color augmentation settings
                self.__dict__["_data_provider"] = DataProviderClass(
                    train_batch_size=self.train_batch_size,
                    test_batch_size=self.test_batch_size,
                    valid_size=self.valid_size,
                    n_worker=self.n_worker,
                    resize_scale=None,
                    distort_color=None,
                    image_size=self.image_size,
                    num_replicas=self._num_replicas,
                    rank=self._rank,
                )
        return self.__dict__["_data_provider"]

    @property
    def train_criterion_loss(self):
        """Map `train_criterion` name to its robust-loss callable."""
        if self.train_criterion == "trades":
            return trades_loss
        elif self.train_criterion == "mart":
            return mart_loss
        elif self.train_criterion == "sat":
            return sat_loss
        elif self.train_criterion == "hat":
            return hat_loss
+ @property + def test_criterion_loss (self) : + if self.test_criterion == "ce" : + return torch.nn.CrossEntropyLoss() + @property + def kd_criterion_loss (self) : + if self.kd_criterion =="ard" : + return ard_loss + elif self.kd_criterion == "adaad" : + return adaad_loss + elif self.kd_criterion == "rslad" : + return rslad_loss + + \ No newline at end of file diff --git a/proard/classification/run_manager/run_manager.py b/proard/classification/run_manager/run_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..bc3f999afe900200e88ab2400a20fdf06905aa00 --- /dev/null +++ b/proard/classification/run_manager/run_manager.py @@ -0,0 +1,484 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import os +import random +import time +import json +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +from tqdm import tqdm +from attacks.utils import ctx_noparamgrad_and_eval +from robust_loss.rslad import rslad_inner_loss,kl_loss +from robust_loss.trades import trades_loss +from attacks import create_attack +from proard.utils import ( + get_net_info, + cross_entropy_loss_with_soft_target, + cross_entropy_with_label_smoothing, +) +from proard.utils import ( + AverageMeter, + accuracy, + write_log, + mix_images, + mix_labels, + init_models, +) +from proard.utils import MyRandomResizedCrop + +__all__ = ["RunManager"] + + +class RunManager: + def __init__( + self, path, net, run_config, init=True, measure_latency=None, no_gpu=False + ): + self.path = path + self.net = net + self.run_config = run_config + + self.best_acc = 0 + self.best_robustness = 0 + self.start_epoch = 0 + + os.makedirs(self.path, exist_ok=True) + + # move network to GPU if available + if torch.cuda.is_available() and (not 
no_gpu): + self.device = torch.device("cuda") + self.net = self.net.to(self.device) + cudnn.benchmark = True + else: + self.device = torch.device("cpu") + # initialize model (default) + if init: + init_models(net,run_config.model_init) + + # net info + net_info = get_net_info( + self.net, self.run_config.data_provider.data_shape, measure_latency, True + ) + with open("%s/net_info.txt" % self.path, "w") as fout: + fout.write(json.dumps(net_info, indent=4) + "\n") + # noinspection PyBroadException + try: + fout.write(self.network.module_str + "\n") + except Exception: + pass + fout.write("%s\n" % self.run_config.data_provider.train.dataset.transform) + fout.write("%s\n" % self.run_config.data_provider.test.dataset.transform) + fout.write("%s\n" % self.network) + + self.train_criterion = self.run_config.train_criterion_loss + self.test_criterion = self.run_config.test_criterion_loss + self.kd_criterion = self.run_config.kd_criterion_loss + + # optimizer + if self.run_config.no_decay_keys: + keys = self.run_config.no_decay_keys.split("#") + net_params = [ + self.network.get_parameters( + keys, mode="exclude" + ), # parameters with weight decay + self.network.get_parameters( + keys, mode="include" + ), # parameters without weight decay + ] + else: + # noinspection PyBroadException + try: + net_params = self.network.weight_parameters() + except Exception: + net_params = [] + for param in self.network.parameters(): + if param.requires_grad: + net_params.append(param) + self.optimizer = self.run_config.build_optimizer(net_params) + + self.net = torch.nn.DataParallel(self.net) + + """ save path and log path """ + + @property + def save_path(self): + if self.__dict__.get("_save_path", None) is None: + save_path = os.path.join(self.path, "checkpoint") + os.makedirs(save_path, exist_ok=True) + self.__dict__["_save_path"] = save_path + return self.__dict__["_save_path"] + + @property + def logs_path(self): + if self.__dict__.get("_logs_path", None) is None: + logs_path = 
os.path.join(self.path, "logs") + os.makedirs(logs_path, exist_ok=True) + self.__dict__["_logs_path"] = logs_path + return self.__dict__["_logs_path"] + + @property + def network(self): + return self.net.module if isinstance(self.net, nn.DataParallel) else self.net + + def write_log(self, log_str, prefix="valid", should_print=True, mode="a"): + write_log(self.logs_path, log_str, prefix, should_print, mode) + + """ save and load models """ + + def save_model(self, checkpoint=None, is_best=False, model_name=None): + if checkpoint is None: + checkpoint = {"state_dict": self.network.state_dict()} + + if model_name is None: + model_name = "checkpoint.pth.tar" + + checkpoint[ + "dataset" + ] = self.run_config.dataset # add `dataset` info to the checkpoint + latest_fname = os.path.join(self.save_path, "latest.txt") + model_path = os.path.join(self.save_path, model_name) + with open(latest_fname, "w") as fout: + fout.write(model_path + "\n") + torch.save(checkpoint, model_path) + + if is_best: + best_path = os.path.join(self.save_path, "model_best.pth.tar") + torch.save({"state_dict": checkpoint["state_dict"]}, best_path) + + def load_model(self, model_fname=None): + latest_fname = os.path.join(self.save_path, "latest.txt") + if model_fname is None and os.path.exists(latest_fname): + with open(latest_fname, "r") as fin: + model_fname = fin.readline() + if model_fname[-1] == "\n": + model_fname = model_fname[:-1] + # noinspection PyBroadException + try: + if model_fname is None or not os.path.exists(model_fname): + model_fname = "%s/checkpoint.pth.tar" % self.save_path + with open(latest_fname, "w") as fout: + fout.write(model_fname + "\n") + print("=> loading checkpoint '{}'".format(model_fname)) + checkpoint = torch.load(model_fname, map_location="cpu") + except Exception: + print("fail to load checkpoint from %s" % self.save_path) + return {} + + self.network.load_state_dict(checkpoint["state_dict"]) + if "epoch" in checkpoint: + self.start_epoch = checkpoint["epoch"] + 
1 + if "best_acc" in checkpoint: + self.best_acc = checkpoint["best_acc"] + if "optimizer" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + + print("=> loaded checkpoint '{}'".format(model_fname)) + return checkpoint + + def save_config(self, extra_run_config=None, extra_net_config=None): + """dump run_config and net_config to the model_folder""" + run_save_path = os.path.join(self.path, "run.config") + if not os.path.isfile(run_save_path): + run_config = self.run_config.config + if extra_run_config is not None: + run_config.update(extra_run_config) + json.dump(run_config, open(run_save_path, "w"), indent=4) + print("Run configs dump to %s" % run_save_path) + + try: + net_save_path = os.path.join(self.path, "net.config") + net_config = self.network.config + if extra_net_config is not None: + net_config.update(extra_net_config) + json.dump(net_config, open(net_save_path, "w"), indent=4) + print("Network configs dump to %s" % net_save_path) + except Exception: + print("%s do not support net config" % type(self.network)) + + """ metric related """ + + def get_metric_dict(self): + return { + "top1": AverageMeter(), + "top5": AverageMeter(), + "robust1" :AverageMeter(), + "robust5" :AverageMeter(), + } + + def update_metric(self, metric_dict, output, output_adv, labels): + acc1, acc5 = accuracy(output, labels, topk=(1, 5)) + robust1,robust5 = accuracy(output_adv,labels,topk=(1,5)) + metric_dict["top1"].update(acc1[0].item(), output.size(0)) + metric_dict["top5"].update(acc5[0].item(), output.size(0)) + metric_dict["robust1"].update(robust1[0].item(), output.size(0)) + metric_dict["robust5"].update(robust5[0].item(), output.size(0)) + + + def get_metric_vals(self, metric_dict, return_dict=False): + if return_dict: + return {key: metric_dict[key].avg for key in metric_dict} + else: + return [metric_dict[key].avg for key in metric_dict] + + def get_metric_names(self): + return "top1", "top5" , "robust1" , "robust5" + + """ train and test """ + + 
def validate( + self, + epoch=0, + is_test=False, + run_str="", + net=None, + data_loader=None, + no_logs=False, + train_mode=False, + ): + if net is None: + net = self.net + if not isinstance(net, nn.DataParallel): + net = nn.DataParallel(net) + if data_loader is None: + data_loader = ( + self.run_config.test_loader if is_test else self.run_config.valid_loader + ) + + if train_mode: + net.train() + else: + net.eval() + if self.run_config.robust_mode: + eval_attack = create_attack(net, self.test_criterion.cuda(), self.run_config.attack_type,self.run_config.epsilon_test,self.run_config.num_steps_test, self.run_config.step_size_test) + losses = AverageMeter() + metric_dict = self.get_metric_dict() + + with tqdm( + total=len(data_loader), + desc="Validate Epoch #{} {}".format(epoch + 1, run_str), + disable=no_logs, + ) as t: + for i, (images, labels) in enumerate(data_loader): + images, labels = images.to(self.device), labels.to(self.device) + # compute output + output = net(images) + if self.run_config.robust_mode: + with ctx_noparamgrad_and_eval(net): + images_adv,_ = eval_attack.perturb(images, labels) + output_adv = net(images_adv) + loss = nn.CrossEntropyLoss()(output_adv,labels) + else: + output_adv = output + loss = nn.CrossEntropyLoss()(output,labels) + + # measure accuracy and record loss + self.update_metric(metric_dict, output, output_adv , labels) + + losses.update(loss.item(), images.size(0)) + t.set_postfix( + { + "loss": losses.avg, + **self.get_metric_vals(metric_dict, return_dict=True), + "img_size": images.size(2), + } + ) + t.update(1) + return losses.avg, self.get_metric_vals(metric_dict) + + def validate_all_resolution(self, epoch=0, is_test=False, net=None): + if net is None: + net = self.network + if isinstance(self.run_config.data_provider.image_size, list): + img_size_list, loss_list, top1_list, top5_list , robust1_list , robust5_list = [], [], [], [],[],[] + for img_size in self.run_config.data_provider.image_size: + 
img_size_list.append(img_size) + self.run_config.data_provider.assign_active_img_size(img_size) + self.reset_running_statistics(net=net) + loss, (top1, top5 , robust1,robust5) = self.validate(epoch, is_test, net=net) + loss_list.append(loss) + top1_list.append(top1) + top5_list.append(top5) + robust1_list.append(robust1) + robust5_list.append(robust5) + return img_size_list, loss_list, top1_list, top5_list ,robust1_list ,robust5_list + else: + loss, (top1, top5 , robust1 , robust5) = self.validate(epoch, is_test, net=net) + return ( + [self.run_config.data_provider.active_img_size], + [loss], + [top1], + [top5], + [robust1], + [robust5] + ) + + def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0): + # switch to train mode + self.net.train() + MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution + + nBatch = len(self.run_config.train_loader) + + losses = AverageMeter() + metric_dict = self.get_metric_dict() + data_time = AverageMeter() + + with tqdm( + total=nBatch, + desc="{} Train Epoch #{}".format(self.run_config.dataset, epoch + 1), + ) as t: + end = time.time() + for i, (images, labels) in enumerate(self.run_config.train_loader): + MyRandomResizedCrop.BATCH = i + data_time.update(time.time() - end) + if epoch < warmup_epochs: + new_lr = self.run_config.warmup_adjust_learning_rate( + self.optimizer, + warmup_epochs * nBatch, + nBatch, + epoch, + i, + warmup_lr, + ) + else: + new_lr = self.run_config.adjust_learning_rate( + self.optimizer, epoch - warmup_epochs, i, nBatch + ) + + images, labels = images.to(self.device), labels.to(self.device) + target = labels + if isinstance(self.run_config.mixup_alpha, float): + # transform data + lam = random.betavariate( + self.run_config.mixup_alpha, self.run_config.mixup_alpha + ) + images = mix_images(images, lam) + labels = mix_labels( + labels, + lam, + self.run_config.data_provider.n_classes, + self.run_config.label_smoothing, + ) + + # soft target + if args.teacher_model is not None: + 
args.teacher_model.train() + with torch.no_grad(): + soft_logits = args.teacher_model(images).detach() + soft_label = F.softmax(soft_logits, dim=1) + + # compute output + output = self.net(images) + + if args.teacher_model is None: + if self.run_config.robust_mode: + loss = self.train_criterion(self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train,self.run_config.distance_train) + loss_type = self.run_config.train_criterion + else: + loss = torch.nn.CrossEntropyLoss()(output,labels) + loss_type = 'ce' + + else: + if self.run_config.robust_mode: + loss = self.kd_criterion(args.teacher_model,self.net,images,labels,self.optimizer,self.run_config.step_size_train,self.run_config.epsilon_train,self.run_config.num_steps_train,self.run_config.beta_train) + loss_type = self.run_config.train_criterion + else: + if args.kd_type == "ce": + kd_loss = cross_entropy_loss_with_soft_target( + output, soft_label + ) + else: + kd_loss = F.mse_loss(output, soft_logits) + loss = args.kd_ratio * kd_loss + torch.nn.CrossEntropyLoss()(output, labels) + loss_type = "%.1fkd+ce" % args.kd_ratio + + # compute gradient and do SGD step + self.net.zero_grad() # or self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # measure accuracy and record loss + losses.update(loss.item(), images.size(0)) + self.update_metric(metric_dict, output, output ,target) + + t.set_postfix( + { + "loss": losses.avg, + **self.get_metric_vals(metric_dict, return_dict=True), + "img_size": images.size(2), + "lr": new_lr, + "loss_type": loss_type, + "data_time": data_time.avg, + } + ) + t.update(1) + end = time.time() + return losses.avg, self.get_metric_vals(metric_dict) + + def train(self, args, warmup_epoch=0, warmup_lr=0): + for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch): + train_loss, (train_top1, train_top5 , train_robust1 , train_robust5) = self.train_one_epoch( + args, epoch, warmup_epoch, warmup_lr
+ ) + + if (epoch + 1) % self.run_config.validation_frequency == 0: + img_size, val_loss, val_acc, val_acc5 ,val_robust, val_robust5 = self.validate_all_resolution( + epoch=epoch, is_test=False + ) + + is_best = np.mean(val_acc) > self.best_acc + is_best_robust = np.mean(val_robust) > self.best_robustness + self.best_acc = max(self.best_acc, np.mean(val_acc)) + self.best_robustness = max(self.best_robustness, np.mean(val_robust)) + val_log = "Valid [{0}/{1}]\tloss {2:.3f} \t{7} {3:.3f} ({5:.3f}) \t{8} {4:.3f} ({6:.3f})".format( + epoch + 1 - warmup_epoch, + self.run_config.n_epochs, + np.mean(val_loss), + np.mean(val_acc), + np.mean(val_robust), + self.best_acc, + self.best_robustness, + self.get_metric_names()[0], + self.get_metric_names()[2], + ) + val_log += "\t{2} {0:.3f} \tTrain {1} {top1:.3f}\t {3} {robust:.3f} \t loss {train_loss:.3f}\t".format( + np.mean(val_acc5), + *self.get_metric_names(), + top1=train_top1, + robust = train_robust1, + train_loss=train_loss + ) + for i_s, v_a in zip(img_size, val_acc): + val_log += "(%d, %.3f), " % (i_s, v_a) + self.write_log(val_log, prefix="valid", should_print=False) + else: + is_best = False + is_best_robust = False + + self.save_model( + { + "epoch": epoch, + "best_acc": self.best_acc, + "optimizer": self.optimizer.state_dict(), + "state_dict": self.network.state_dict(), + }, + is_best=is_best, + ) + + def reset_running_statistics( + self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None + ): + from proard.classification.elastic_nn.utils import set_running_statistics + + if net is None: + net = self.network + if data_loader is None: + data_loader = self.run_config.random_sub_train_loader( + subset_size, subset_batch_size + ) + set_running_statistics(net, data_loader) diff --git a/proard/model_zoo.py b/proard/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..681838052846a2e05939053ab21847311704a605 --- /dev/null +++ b/proard/model_zoo.py @@ -0,0 +1,162 @@ +# Once for 
All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import json +import torch +import gdown + +from proard.classification.networks import get_net_by_name, ResNet50 +from proard.classification.elastic_nn.networks import ( + DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar +) +from proard.classification.networks import (WideResNet,ResNet50_Cifar,ResNet50,MobileNetV3_Cifar,MobileNetV3Large_Cifar,MobileNetV3Large,ProxylessNASNets_Cifar,ProxylessNASNets,MobileNetV2_Cifar,MobileNetV2) +__all__ = [ + "DYN_net", +] + + + +def DYN_net(net_id, robust_mode, dataset,train_criterion, pretrained=True,run_config=None,WPS=False,base=False): + if net_id == "ResNet50": + if not base: + if dataset == "cifar10" or dataset == "cifar100": + net = DYNResNets_Cifar(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + depth_list=[0, 1, 2], + expand_ratio_list=[0.2, 0.25, 0.35], + width_mult_list=[0.65, 0.8, 1.0], + ) + else: + net = DYNResNets(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + depth_list=[0, 1, 2], + expand_ratio_list=[0.2, 0.25, 0.35], + width_mult_list=[0.65, 0.8, 1.0], + ) + else: + if dataset == "cifar10" or dataset == "cifar100": + net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes) + else: + net = ResNet50(n_classes=run_config.data_provider.n_classes) + + elif net_id == "MBV3": + if not base: + if dataset == "cifar10" or dataset == "cifar100": + net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) + else: + net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) 
+ else: + if dataset == "cifar10" or dataset == "cifar100": + net = MobileNetV3Large_Cifar(n_classes=run_config.data_provider.n_classes) + else: + net = MobileNetV3Large(n_classes=run_config.data_provider.n_classes) + elif net_id == "ProxylessNASNet": + if not base: + if dataset == "cifar10" or dataset == "cifar100": + net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) + else: + net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) + else: + if dataset == "cifar10" or dataset == "cifar100": + net = ProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes) + else: + net = ProxylessNASNets(n_classes=run_config.data_provider.n_classes) + elif net_id == "MBV2": + if not base: + if dataset == "cifar10" or dataset == "cifar100": + net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + base_stage_width="google", + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) + else: + net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, + dropout_rate=0, + base_stage_width="google", + width_mult=1.0, + ks_list=[3, 5, 7], + expand_ratio_list=[3, 4, 6], + depth_list=[2, 3, 4], + ) + else: + if dataset == "cifar10" or dataset == "cifar100": + net = MobileNetV2_Cifar(n_classes=run_config.data_provider.n_classes) + else: + net = MobileNetV2(n_classes=run_config.data_provider.n_classes) + elif net_id == "WideResNet": + if dataset == "cifar10" or dataset == "cifar100": + net = WideResNet(num_classes=run_config.data_provider.n_classes) + else: + raise ValueError("Not supported: %s" % net_id) + + else: + raise ValueError("Not supported: %s" % net_id) + + if pretrained and not WPS and not base: + if 
net_id == "ResNet50": + if robust_mode: + pt_path = "exp/robust/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + else: + pt_path = "exp/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + else: + if robust_mode: + pt_path = "exp/robust/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + + else: + pt_path = "exp/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + elif pretrained and WPS and not base: + if net_id == "ResNet50": + if robust_mode: + pt_path = "exp/robust/WPS/"+ dataset + "/" + net_id + '/' + train_criterion +"/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + else: + pt_path = "exp/WPS/"+ dataset + "/" + net_id + '/' + train_criterion + "/width_depth2width_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + else: + if robust_mode: + pt_path = "exp/robust/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + + else: + pt_path = "exp/WPS/"+ dataset + '/' + net_id + '/' + train_criterion +"/kernel_depth2kernel_depth_width/phase2" + "/checkpoint/model_best.pth.tar" + else: + if not base: + pt_path = "exp/robust/teacher/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar" + else: + pt_path = "exp/robust/base/"+ dataset + '/' + net_id + '/' + train_criterion + "/checkpoint/model_best.pth.tar" + print(pt_path) + init = torch.load(pt_path, map_location="cuda")["state_dict"] + # from collections import OrderedDict + # new_state_dict = OrderedDict() + # for k, v in init.items(): + # name = k[7:] # remove `module.` + # new_state_dict[name] = v + net.load_state_dict(init) + return net + + diff --git 
a/proard/nas/__init__.py b/proard/nas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/proard/nas/accuracy_predictor/__init__.py b/proard/nas/accuracy_predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3efd8f92e151b954501eeab137ffd5ea11bcee --- /dev/null +++ b/proard/nas/accuracy_predictor/__init__.py @@ -0,0 +1,11 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +from .acc_dataset import * +from .acc_predictor import * +from .arch_encoder import * +from .rob_dataset import * +from .rob_predictor import * +from .acc_rob_dataset import * +from .acc_rob_predictor import * diff --git a/proard/nas/accuracy_predictor/acc_dataset.py b/proard/nas/accuracy_predictor/acc_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..52728b1f1ad0981bf04bf07a87db41bc18b7f809 --- /dev/null +++ b/proard/nas/accuracy_predictor/acc_dataset.py @@ -0,0 +1,213 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import os +import json +import numpy as np +from tqdm import tqdm +import torch +import torch.utils.data + +from proard.utils import list_mean + +__all__ = ["net_setting2id", "net_id2setting", "AccuracyDataset"] + + +def net_setting2id(net_setting): + return json.dumps(net_setting) + + +def net_id2setting(net_id): + return json.loads(net_id) + + +class RegDataset(torch.utils.data.Dataset): + def __init__(self, inputs, targets): + super(RegDataset, self).__init__() + self.inputs = inputs + self.targets = targets + + def __getitem__(self, index): + return self.inputs[index], self.targets[index] + + def __len__(self): + return self.inputs.size(0) + + +class AccuracyDataset: + def __init__(self, path): + self.path = path + os.makedirs(self.path, exist_ok=True) + + @property + def net_id_path(self): + return os.path.join(self.path, "net_id.dict") + + @property + def acc_src_folder(self): + return os.path.join(self.path, "src") + @property + def acc_dict_path(self): + return os.path.join(self.path, "src/acc.dict") + + + # TODO: support parallel building + def build_acc_dataset( + self, run_manager, dyn_network, n_arch=2000, image_size_list=None + ): + # load net_id_list, random sample if not exist + if os.path.isfile(self.net_id_path): + net_id_list = json.load(open(self.net_id_path)) + else: + net_id_list = set() + while len(net_id_list) < n_arch: + net_setting = dyn_network.sample_active_subnet() + net_id = net_setting2id(net_setting) + net_id_list.add(net_id) + net_id_list = list(net_id_list) + net_id_list.sort() + json.dump(net_id_list, open(self.net_id_path, "w"), indent=4) + + image_size_list = ( + [128, 160, 192, 224] if image_size_list is None else image_size_list + ) + print(image_size_list) + with tqdm( + total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset" + ) as t: + for image_size in image_size_list: + # load val dataset into memory + val_dataset = [] + run_manager.run_config.data_provider.assign_active_img_size(image_size) + for 
images, labels in run_manager.run_config.valid_loader: + val_dataset.append((images, labels)) + # save path + os.makedirs(self.acc_src_folder, exist_ok=True) + acc_save_path = os.path.join( + self.acc_src_folder, "%d.dict" % image_size + ) + acc_dict = {} + # load existing acc dict + if os.path.isfile(acc_save_path): + existing_acc_dict = json.load(open(acc_save_path, "r")) + else: + existing_acc_dict = {} + for net_id in net_id_list: + net_setting = net_id2setting(net_id) + key = net_setting2id({**net_setting, "image_size": image_size}) + if key in existing_acc_dict: + acc_dict[key] = existing_acc_dict[key] + t.set_postfix( + { + "net_id": net_id, + "image_size": image_size, + "info_val": acc_dict[key], + "status": "loading", + } + ) + t.update() + continue + dyn_network.set_active_subnet(**net_setting) + run_manager.reset_running_statistics(dyn_network) + net_setting_str = ",".join( + [ + "%s_%s" + % ( + key, + "%.1f" % list_mean(val) + if isinstance(val, list) + else val, + ) + for key, val in net_setting.items() + ] + ) + loss, (top1, top5,robust1,robust5) = run_manager.validate( + run_str=net_setting_str, + net=dyn_network, + data_loader=val_dataset, + no_logs=True, + ) + info_val = top1 + t.set_postfix( + { + "net_id": net_id, + "image_size": image_size, + "info_val": info_val, + } + ) + t.update() + + acc_dict.update({key: info_val}) + json.dump(acc_dict, open(acc_save_path, "w"), indent=4) + + + def merge_acc_dataset(self, image_size_list=None): + # load existing data + merged_acc_dict = {} + for fname in os.listdir(self.acc_src_folder): + if ".dict" not in fname: + continue + image_size = int(fname.split(".dict")[0]) + if image_size_list is not None and image_size not in image_size_list: + print("Skip ", fname) + continue + full_path = os.path.join(self.acc_src_folder, fname) + partial_acc_dict = json.load(open(full_path)) + merged_acc_dict.update(partial_acc_dict) + print("loaded %s" % full_path) + json.dump(merged_acc_dict, open(self.acc_dict_path, "w"), 
indent=4) + return merged_acc_dict + + def build_acc_data_loader( + self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16 + ): + # load data + acc_dict = json.load(open(self.acc_dict_path)) + X_all = [] + Y_all = [] + + with tqdm(total=len(acc_dict), desc="Loading data") as t: + for k, v in acc_dict.items(): + dic = json.loads(k) + X_all.append(arch_encoder.arch2feature(dic)) + Y_all.append(v / 100.0) # range: 0 - 1 + t.update() + base_acc = np.mean(Y_all) + # convert to torch tensor + X_all = torch.tensor(X_all, dtype=torch.float) + Y_all = torch.tensor(Y_all) + + + # random shuffle + shuffle_idx = torch.randperm(len(X_all)) + X_all = X_all[shuffle_idx] + Y_all = Y_all[shuffle_idx] + # split data + idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample + val_idx = X_all.size(0) // 5 * 4 + X_train, Y_train = X_all[:idx], Y_all[:idx] + X_test, Y_test = X_all[val_idx:], Y_all[val_idx:] + print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test)) + + # build data loader + train_dataset = RegDataset(X_train, Y_train) + val_dataset = RegDataset(X_test, Y_test) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=False, + num_workers=n_workers, + ) + valid_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=False, + num_workers=n_workers, + ) + + return train_loader, valid_loader, base_acc + + diff --git a/proard/nas/accuracy_predictor/acc_predictor.py b/proard/nas/accuracy_predictor/acc_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..59d2f804e2a91680dd55fe19253125b4660b2258 --- /dev/null +++ b/proard/nas/accuracy_predictor/acc_predictor.py @@ -0,0 +1,68 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import os +import numpy as np +import torch +import torch.nn as nn + +__all__ = ["AccuracyPredictor"] + + +class AccuracyPredictor(nn.Module): + def __init__( + self, + arch_encoder, + hidden_size=400, + n_layers=3, + checkpoint_path=None, + device="cuda:0", + base_acc_val = None + ): + super(AccuracyPredictor, self).__init__() + self.arch_encoder = arch_encoder + self.hidden_size = hidden_size + self.n_layers = n_layers + self.device = device + self.base_acc_val = base_acc_val + # build layers + layers = [] + for i in range(self.n_layers): + layers.append( + nn.Sequential( + nn.Linear( + self.arch_encoder.n_dim if i == 0 else self.hidden_size, + self.hidden_size, + ), + nn.ReLU(inplace=True), + ) + ) + layers.append(nn.Linear(self.hidden_size, 1, bias=False)) + self.layers = nn.Sequential(*layers) + if self.base_acc_val!=None : + self.base_acc = nn.Parameter( + torch.tensor(self.base_acc_val, device=self.device), requires_grad=False + ) + else: + self.base_acc = nn.Parameter( + torch.zeros(1, device=self.device), requires_grad=False + ) + + if checkpoint_path is not None and os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + self.load_state_dict(checkpoint) + print("Loaded checkpoint from %s" % checkpoint_path) + + self.layers = self.layers.to(self.device) + + def forward(self, x): + y = self.layers(x).squeeze() + return y + self.base_acc + + def predict_acc(self, arch_dict_list): + X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list] + X = torch.tensor(np.array(X)).float().to(self.device) + return self.forward(X) diff --git a/proard/nas/accuracy_predictor/acc_rob_dataset.py b/proard/nas/accuracy_predictor/acc_rob_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7da4dc77e356d32f3d4cba62a854cb9d9faefb56 --- /dev/null +++ b/proard/nas/accuracy_predictor/acc_rob_dataset.py @@ -0,0 +1,219 
@@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import os +import json +import numpy as np +from tqdm import tqdm +import torch +import torch.utils.data + +from proard.utils import list_mean + +__all__ = ["net_setting2id", "net_id2setting", "AccuracyRobustnessDataset"] + + +def net_setting2id(net_setting): + return json.dumps(net_setting) + + +def net_id2setting(net_id): + return json.loads(net_id) + + +class TwoRegDataset(torch.utils.data.Dataset): + def __init__(self, inputs, targets_acc , targets_rob ): + super(TwoRegDataset, self).__init__() + self.inputs = inputs + self.targets_acc = targets_acc + self.targets_rob = targets_rob + + def __getitem__(self, index): + return self.inputs[index], self.targets_acc[index] , self.targets_rob[index] + + def __len__(self): + return self.inputs.size(0) + + +class AccuracyRobustnessDataset: + def __init__(self, path): + self.path = path + os.makedirs(self.path, exist_ok=True) + + @property + def net_id_path(self): + return os.path.join(self.path, "net_id.dict") + + @property + def acc_rob_src_folder(self): + return os.path.join(self.path, "src") + @property + def acc_rob_dict_path(self): + return os.path.join(self.path, "src/acc_robust.dict") + + + # TODO: support parallel building + def build_acc_rob_dataset( + self, run_manager, dyn_network, n_arch=2000, image_size_list=None + ): + # load net_id_list, random sample if not exist + if os.path.isfile(self.net_id_path): + net_id_list = json.load(open(self.net_id_path)) + else: + net_id_list = set() + while len(net_id_list) < n_arch: + net_setting = dyn_network.sample_active_subnet() + net_id = net_setting2id(net_setting) + net_id_list.add(net_id) + net_id_list = list(net_id_list) + net_id_list.sort() + json.dump(net_id_list, open(self.net_id_path, "w"), indent=4) + + image_size_list = ( + [128, 160, 192, 224] if 
image_size_list is None else image_size_list + ) + print(image_size_list) + with tqdm( + total=len(net_id_list) * len(image_size_list), desc="Building Acc Dataset" + ) as t: + for image_size in image_size_list: + # load val dataset into memory + val_dataset = [] + run_manager.run_config.data_provider.assign_active_img_size(image_size) + for images, labels in run_manager.run_config.valid_loader: + val_dataset.append((images, labels)) + # save path + os.makedirs(self.acc_rob_src_folder, exist_ok=True) + acc_rob_save_path = os.path.join( + self.acc_rob_src_folder, "%d.dict" % image_size + ) + acc_rob_dict = {} + # load existing acc dict + if os.path.isfile(acc_rob_save_path): + existing_acc_rob_dict = json.load(open(acc_rob_save_path, "r")) + else: + existing_acc_rob_dict = {} + for net_id in net_id_list: + net_setting = net_id2setting(net_id) + key = net_setting2id({**net_setting, "image_size": image_size}) + if key in existing_acc_rob_dict: + acc_rob_dict[key] = existing_acc_rob_dict[key] + t.set_postfix( + { + "net_id": net_id, + "image_size": image_size, + "info_val": acc_rob_dict[key], + "status": "loading", + } + ) + t.update() + continue + dyn_network.set_active_subnet(**net_setting) + run_manager.reset_running_statistics(dyn_network) + net_setting_str = ",".join( + [ + "%s_%s" + % ( + key, + "%.1f" % list_mean(val) + if isinstance(val, list) + else val, + ) + for key, val in net_setting.items() + ] + ) + loss, (top1, top5,robust1,robust5) = run_manager.validate( + run_str=net_setting_str, + net=dyn_network, + data_loader=val_dataset, + no_logs=True, + ) + info_val = [top1,robust1] + t.set_postfix( + { + "net_id": net_id, + "image_size": image_size, + "info_val": info_val, + } + ) + t.update() + + acc_rob_dict.update({key: info_val}) + json.dump(acc_rob_dict, open(acc_rob_save_path, "w"), indent=4) + + + def merge_acc_dataset(self, image_size_list=None): + # load existing data + merged_acc_rob_dict = {} + for fname in os.listdir(self.acc_rob_src_folder): + if 
".dict" not in fname: + continue + image_size = int(fname.split(".dict")[0]) + if image_size_list is not None and image_size not in image_size_list: + print("Skip ", fname) + continue + full_path = os.path.join(self.acc_rob_src_folder, fname) + partial_acc_rob_dict = json.load(open(full_path)) + merged_acc_rob_dict.update(partial_acc_rob_dict) + print("loaded %s" % full_path) + json.dump(merged_acc_rob_dict, open(self.acc_rob_dict_path, "w"), indent=4) + return merged_acc_rob_dict + + def build_acc_data_loader( + self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16 + ): + # load data + acc_rob_dict = json.load(open(self.acc_rob_dict_path)) + X_all = [] + Y_acc_all = [] + Y_rob_all = [] + + with tqdm(total=len(acc_rob_dict), desc="Loading data") as t: + for k, v in acc_rob_dict.items(): + dic = json.loads(k) + X_all.append(arch_encoder.arch2feature(dic)) + Y_acc_all.append(v[0] / 100.0) # range: 0 - 1 + Y_rob_all.append(v[1] / 100.0) + t.update() + base_acc = np.mean(Y_acc_all) + base_rob = np.mean(Y_rob_all) + # convert to torch tensor + X_all = torch.tensor(X_all, dtype=torch.float) + Y_acc_all = torch.tensor(Y_acc_all) + Y_rob_all = torch.tensor(Y_rob_all) + + + # random shuffle + shuffle_idx = torch.randperm(len(X_all)) + X_all = X_all[shuffle_idx] + Y_acc_all = Y_acc_all[shuffle_idx] + Y_rob_all = Y_rob_all[shuffle_idx] + # split data + idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample + val_idx = X_all.size(0) // 5 * 4 + X_train, Y_acc_train, Y_rob_train = X_all[:idx], Y_acc_all[:idx], Y_rob_all[:idx] + X_test, Y_acc_test , Y_rob_test = X_all[val_idx:], Y_acc_all[val_idx:] , Y_rob_all[val_idx:] + print("Train Size: %d," % len(X_train), "Valid Size: %d" % len(X_test)) + + # build data loader + train_dataset = TwoRegDataset(X_train, Y_acc_train , Y_rob_train) + val_dataset = TwoRegDataset(X_test, Y_acc_test ,Y_rob_test ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + 
shuffle=True, + pin_memory=False, + num_workers=n_workers, + ) + valid_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=False, + num_workers=n_workers, + ) + + return train_loader, valid_loader, base_acc, base_rob + + diff --git a/proard/nas/accuracy_predictor/acc_rob_predictor.py b/proard/nas/accuracy_predictor/acc_rob_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d67745ae99b8ba05c0e70c6a51ce00e24fe9da75 --- /dev/null +++ b/proard/nas/accuracy_predictor/acc_rob_predictor.py @@ -0,0 +1,77 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import os +import numpy as np +import torch +import torch.nn as nn + +__all__ = ["Accuracy_Robustness_Predictor"] + + +class Accuracy_Robustness_Predictor(nn.Module): + def __init__( + self, + arch_encoder, + hidden_size=400, + n_layers=6, + checkpoint_path=None, + device="cuda:0", + base_acc_val = None, + base_rob_val = None + ): + super(Accuracy_Robustness_Predictor, self).__init__() + self.arch_encoder = arch_encoder + self.hidden_size = hidden_size + self.n_layers = n_layers + self.device = device + self.base_acc_val = base_acc_val + self.base_rob_val = base_rob_val + # build layers + layers = [] + for i in range(self.n_layers): + layers.append( + nn.Sequential( + nn.Linear( + self.arch_encoder.n_dim if i == 0 else self.hidden_size, + self.hidden_size, + ), + nn.ReLU(inplace=True), + ) + ) + layers.append(nn.Linear(self.hidden_size, 2, bias=False)) + self.layers = nn.Sequential(*layers) + if self.base_acc_val!=None : + self.base_acc = nn.Parameter( + torch.tensor(self.base_acc_val, device=self.device), requires_grad=False + ) + else: + self.base_acc = nn.Parameter( + torch.zeros(1, device=self.device), requires_grad=False + ) + if self.base_rob_val!=None : + self.base_rob 
= nn.Parameter( + torch.tensor(self.base_rob_val, device=self.device), requires_grad=False + ) + else: + self.base_rob = nn.Parameter( + torch.zeros(1, device=self.device), requires_grad=False + ) + if checkpoint_path is not None and os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + self.load_state_dict(checkpoint) + print("Loaded checkpoint from %s" % checkpoint_path) + + self.layers = self.layers.to(self.device) + + def forward(self, x): + y = self.layers(x).squeeze() + # column 0 is accuracy, column 1 is robustness; apply the matching base offsets + return y + torch.stack([self.base_acc.reshape(()), self.base_rob.reshape(())]) + + def predict_acc_rob(self, arch_dict_list): + X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list] + X = torch.tensor(np.array(X)).float().to(self.device) + return self.forward(X) diff --git a/proard/nas/accuracy_predictor/arch_encoder.py b/proard/nas/accuracy_predictor/arch_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd7cb2acf981bcf898adb20ee2dfcad2c5b8c42 --- /dev/null +++ b/proard/nas/accuracy_predictor/arch_encoder.py @@ -0,0 +1,372 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020.
+ + +import random +import numpy as np +from proard.classification.networks import ResNets + +__all__ = ["MobileNetArchEncoder", "ResNetArchEncoder"] + + +class MobileNetArchEncoder: + SPACE_TYPE = "mbv3" + + def __init__( + self, + image_size_list=None, + ks_list=None, + expand_list=None, + depth_list=None, + n_stage=None, + ): + self.image_size_list = [224] if image_size_list is None else image_size_list + self.ks_list = [3, 5, 7] if ks_list is None else ks_list + self.expand_list = ( + [3, 4, 6] + if expand_list is None + else [int(expand) for expand in expand_list] + ) + self.depth_list = [2, 3, 4] if depth_list is None else depth_list + if n_stage is not None: + self.n_stage = n_stage + elif self.SPACE_TYPE == "mbv2": + self.n_stage = 6 + elif self.SPACE_TYPE == "mbv3": + self.n_stage = 5 + else: + raise NotImplementedError + + # build info dict + self.n_dim = 0 + self.r_info = dict(id2val={}, val2id={}, L=[], R=[]) + self._build_info_dict(target="r") + self.k_info = dict(id2val=[], val2id=[], L=[], R=[]) + self.e_info = dict(id2val=[], val2id=[], L=[], R=[]) + self._build_info_dict(target="k") + self._build_info_dict(target="e") + + @property + def max_n_blocks(self): + if self.SPACE_TYPE == "mbv3": + return self.n_stage * max(self.depth_list) + elif self.SPACE_TYPE == "mbv2": + return (self.n_stage - 1) * max(self.depth_list) + 1 + else: + raise NotImplementedError + + def _build_info_dict(self, target): + if target == "r": + target_dict = self.r_info + target_dict["L"].append(self.n_dim) + for img_size in self.image_size_list: + target_dict["val2id"][img_size] = self.n_dim + target_dict["id2val"][self.n_dim] = img_size + self.n_dim += 1 + target_dict["R"].append(self.n_dim) + else: + if target == "k": + target_dict = self.k_info + choices = self.ks_list + elif target == "e": + target_dict = self.e_info + choices = self.expand_list + else: + raise NotImplementedError + for i in range(self.max_n_blocks): + target_dict["val2id"].append({}) + 
target_dict["id2val"].append({}) + target_dict["L"].append(self.n_dim) + for k in choices: + target_dict["val2id"][i][k] = self.n_dim + target_dict["id2val"][i][self.n_dim] = k + self.n_dim += 1 + target_dict["R"].append(self.n_dim) + + def arch2feature(self, arch_dict): + ks, e, d, r = ( + arch_dict["ks"], + arch_dict["e"], + arch_dict["d"], + arch_dict["image_size"], + ) + feature = np.zeros(self.n_dim) + for i in range(self.max_n_blocks): + nowd = i % max(self.depth_list) + stg = i // max(self.depth_list) + if nowd < d[stg]: + feature[self.k_info["val2id"][i][ks[i]]] = 1 + feature[self.e_info["val2id"][i][e[i]]] = 1 + feature[self.r_info["val2id"][r[0]]] = 1 + return feature + + def feature2arch(self, feature): + img_sz = self.r_info["id2val"][ + int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]])) + + self.r_info["L"][0] + ] + assert img_sz in self.image_size_list + arch_dict = {"ks": [], "e": [], "d": [], "image_size": img_sz} + + d = 0 + for i in range(self.max_n_blocks): + skip = True + for j in range(self.k_info["L"][i], self.k_info["R"][i]): + if feature[j] == 1: + arch_dict["ks"].append(self.k_info["id2val"][i][j]) + skip = False + break + + for j in range(self.e_info["L"][i], self.e_info["R"][i]): + if feature[j] == 1: + arch_dict["e"].append(self.e_info["id2val"][i][j]) + assert not skip + skip = False + break + + if skip: + arch_dict["e"].append(0) + arch_dict["ks"].append(0) + else: + d += 1 + + if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks: + arch_dict["d"].append(d) + d = 0 + return arch_dict + + def random_sample_arch(self): + return { + "ks": random.choices(self.ks_list, k=self.max_n_blocks), + "e": random.choices(self.expand_list, k=self.max_n_blocks), + "d": random.choices(self.depth_list, k=self.n_stage), + "image_size": [random.choice(self.image_size_list)], + } + + def mutate_resolution(self, arch_dict, mutate_prob): + if random.random() < mutate_prob: + arch_dict["image_size"] = 
random.choice(self.image_size_list) + return arch_dict + + def mutate_arch(self, arch_dict, mutate_prob): + for i in range(self.max_n_blocks): + if random.random() < mutate_prob: + arch_dict["ks"][i] = random.choice(self.ks_list) + arch_dict["e"][i] = random.choice(self.expand_list) + + for i in range(self.n_stage): + if random.random() < mutate_prob: + arch_dict["d"][i] = random.choice(self.depth_list) + return arch_dict + + +class ResNetArchEncoder: + def __init__( + self, + image_size_list=None, + depth_list=None, + expand_list=None, + width_mult_list=None, + base_depth_list=None, + ): + self.image_size_list = [224] if image_size_list is None else image_size_list + self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list + self.depth_list = [0, 1, 2] if depth_list is None else depth_list + self.width_mult_list = ( + [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list + ) + + self.base_depth_list = ( + ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list + ) + + """" build info dict """ + self.n_dim = 0 + # resolution + self.r_info = dict(id2val={}, val2id={}, L=[], R=[]) + self._build_info_dict(target="r") + # input stem skip + self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[]) + self._build_info_dict(target="input_stem_d") + # width_mult + self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[]) + self._build_info_dict(target="width_mult") + # expand ratio + self.e_info = dict(id2val=[], val2id=[], L=[], R=[]) + self._build_info_dict(target="e") + + @property + def n_stage(self): + return len(self.base_depth_list) + + @property + def max_n_blocks(self): + return sum(self.base_depth_list) + self.n_stage * max(self.depth_list) + + def _build_info_dict(self, target): + if target == "r": + target_dict = self.r_info + target_dict["L"].append(self.n_dim) + for img_size in self.image_size_list: + target_dict["val2id"][img_size] = self.n_dim + target_dict["id2val"][self.n_dim] = img_size + self.n_dim 
+= 1 + target_dict["R"].append(self.n_dim) + elif target == "input_stem_d": + target_dict = self.input_stem_d_info + target_dict["L"].append(self.n_dim) + for skip in [0, 1]: + target_dict["val2id"][skip] = self.n_dim + target_dict["id2val"][self.n_dim] = skip + self.n_dim += 1 + target_dict["R"].append(self.n_dim) + elif target == "e": + target_dict = self.e_info + choices = self.expand_list + for i in range(self.max_n_blocks): + target_dict["val2id"].append({}) + target_dict["id2val"].append({}) + target_dict["L"].append(self.n_dim) + for e in choices: + target_dict["val2id"][i][e] = self.n_dim + target_dict["id2val"][i][self.n_dim] = e + self.n_dim += 1 + target_dict["R"].append(self.n_dim) + elif target == "width_mult": + target_dict = self.width_mult_info + choices = list(range(len(self.width_mult_list))) + for i in range(self.n_stage + 2): + target_dict["val2id"].append({}) + target_dict["id2val"].append({}) + target_dict["L"].append(self.n_dim) + for w in choices: + target_dict["val2id"][i][w] = self.n_dim + target_dict["id2val"][i][self.n_dim] = w + self.n_dim += 1 + target_dict["R"].append(self.n_dim) + + def arch2feature(self, arch_dict): + d, e, w, r = ( + arch_dict["d"], + arch_dict["e"], + arch_dict["w"], + arch_dict["image_size"], + ) + input_stem_skip = 1 if d[0] > 0 else 0 + d = d[1:] + + feature = np.zeros(self.n_dim) + feature[self.r_info["val2id"][r]] = 1 + feature[self.input_stem_d_info["val2id"][input_stem_skip]] = 1 + for i in range(self.n_stage + 2): + feature[self.width_mult_info["val2id"][i][w[i]]] = 1 + + start_pt = 0 + for i, base_depth in enumerate(self.base_depth_list): + depth = base_depth + d[i] + for j in range(start_pt, start_pt + depth): + feature[self.e_info["val2id"][j][e[j]]] = 1 + start_pt += max(self.depth_list) + base_depth + return feature + + def feature2arch(self, feature): + img_sz = self.r_info["id2val"][ + int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]])) + + self.r_info["L"][0] + ] + input_stem_skip = 
( + self.input_stem_d_info["id2val"][ + int( + np.argmax( + feature[ + self.input_stem_d_info["L"][0] : self.input_stem_d_info[ + "R" + ][0] + ] + ) + ) + + self.input_stem_d_info["L"][0] + ] + * 2 + ) + assert img_sz in self.image_size_list + arch_dict = {"d": [input_stem_skip], "e": [], "w": [], "image_size": img_sz} + + for i in range(self.n_stage + 2): + arch_dict["w"].append( + self.width_mult_info["id2val"][i][ + int( + np.argmax( + feature[ + self.width_mult_info["L"][i] : self.width_mult_info[ + "R" + ][i] + ] + ) + ) + + self.width_mult_info["L"][i] + ] + ) + + d = 0 + skipped = 0 + stage_id = 0 + for i in range(self.max_n_blocks): + skip = True + for j in range(self.e_info["L"][i], self.e_info["R"][i]): + if feature[j] == 1: + arch_dict["e"].append(self.e_info["id2val"][i][j]) + skip = False + break + if skip: + arch_dict["e"].append(0) + skipped += 1 + else: + d += 1 + + if ( + i + 1 == self.max_n_blocks + or (skipped + d) + % (max(self.depth_list) + self.base_depth_list[stage_id]) + == 0 + ): + arch_dict["d"].append(d - self.base_depth_list[stage_id]) + d, skipped = 0, 0 + stage_id += 1 + return arch_dict + + def random_sample_arch(self): + return { + "d": [random.choice([0, 2])] + + random.choices(self.depth_list, k=self.n_stage), + "e": random.choices(self.expand_list, k=self.max_n_blocks), + "w": random.choices( + list(range(len(self.width_mult_list))), k=self.n_stage + 2 + ), + "image_size": random.choice(self.image_size_list), + } + + def mutate_resolution(self, arch_dict, mutate_prob): + if random.random() < mutate_prob: + arch_dict["image_size"] = random.choice(self.image_size_list) + return arch_dict + + def mutate_arch(self, arch_dict, mutate_prob): + # input stem skip + if random.random() < mutate_prob: + arch_dict["d"][0] = random.choice([0, 2]) + # depth + for i in range(1, len(arch_dict["d"])): + if random.random() < mutate_prob: + arch_dict["d"][i] = random.choice(self.depth_list) + # width_mult + for i in range(len(arch_dict["w"])): + if 
random.random() < mutate_prob: + arch_dict["w"][i] = random.choice( + list(range(len(self.width_mult_list))) + ) + # expand ratio + for i in range(len(arch_dict["e"])): + if random.random() < mutate_prob: + arch_dict["e"][i] = random.choice(self.expand_list) diff --git a/proard/nas/accuracy_predictor/rob_dataset.py b/proard/nas/accuracy_predictor/rob_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00b6af0961ab558abf8c833c619bbd1ead6a23f2 --- /dev/null +++ b/proard/nas/accuracy_predictor/rob_dataset.py @@ -0,0 +1,211 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import os +import json +import numpy as np +from tqdm import tqdm +import torch +import torch.utils.data + +from proard.utils import list_mean + +__all__ = ["net_setting2id", "net_id2setting", "RobustnessDataset"] + + +def net_setting2id(net_setting): + return json.dumps(net_setting) + + +def net_id2setting(net_id): + return json.loads(net_id) + + +class RegDataset(torch.utils.data.Dataset): + def __init__(self, inputs, targets): + super(RegDataset, self).__init__() + self.inputs = inputs + self.targets = targets + + def __getitem__(self, index): + return self.inputs[index], self.targets[index] + + def __len__(self): + return self.inputs.size(0) + + +class RobustnessDataset: + def __init__(self, path): + self.path = path + os.makedirs(self.path, exist_ok=True) + + @property + def net_id_path(self): + return os.path.join(self.path, "net_id.dict") + + @property + def rob_src_folder(self): + return os.path.join(self.path, "src_rob") + @property + def rob_dict_path(self): + return os.path.join(self.path, "src_rob/rob.dict") + + # TODO: support parallel building + def build_rob_dataset( + self, run_manager, dyn_network, n_arch=2000, image_size_list=None + ): + # load net_id_list, random sample if not exist + if 
os.path.isfile(self.net_id_path): + net_id_list = json.load(open(self.net_id_path)) + else: + net_id_list = set() + while len(net_id_list) < n_arch: + net_setting = dyn_network.sample_active_subnet() + net_id = net_setting2id(net_setting) + net_id_list.add(net_id) + net_id_list = list(net_id_list) + net_id_list.sort() + json.dump(net_id_list, open(self.net_id_path, "w"), indent=4) + + image_size_list = ( + [128, 160, 192, 224] if image_size_list is None else image_size_list + ) + + with tqdm( + total=len(net_id_list) * len(image_size_list), desc="Building Robustness Dataset" + ) as t: + for image_size in image_size_list: + # load val dataset into memory + val_dataset = [] + run_manager.run_config.data_provider.assign_active_img_size(image_size) + for images, labels in run_manager.run_config.valid_loader: + val_dataset.append((images, labels)) + # save path + os.makedirs(self.rob_src_folder, exist_ok=True) + + rob_save_path = os.path.join( + self.rob_src_folder, "%d.dict" % image_size + ) + + rob_dict ={} + # load existing rob dict + if os.path.isfile(rob_save_path): + existing_rob_dict = json.load(open(rob_save_path,"r")) + else: + existing_rob_dict = {} + for net_id in net_id_list: + net_setting = net_id2setting(net_id) + key = net_setting2id({**net_setting, "image_size": image_size}) + if key in existing_rob_dict: + rob_dict[key] = existing_rob_dict[key] + t.set_postfix( + { + "net_id": net_id, + "image_size": image_size, + "info_rob" : rob_dict[key], + "status": "loading", + } + ) + t.update() + continue + dyn_network.set_active_subnet(**net_setting) + run_manager.reset_running_statistics(dyn_network) + net_setting_str = ",".join( + [ + "%s_%s" + % ( + key, + "%.1f" % list_mean(val) + if isinstance(val, list) + else val, + ) + for key, val in net_setting.items() + ] + ) + loss, (top1, top5,robust1,robust5) = run_manager.validate( + run_str=net_setting_str, + net=dyn_network, + data_loader=val_dataset, + no_logs=True, + ) + info_robust = robust1 + t.set_postfix( 
+ { + "net_id": net_id, + "image_size": image_size, + "info_rob" : info_robust, + "info_robust" : info_robust, + } + ) + t.update() + + rob_dict.update({key:info_robust}) + json.dump(rob_dict, open(rob_save_path, "w"), indent=4) + + def merge_rob_dataset(self, image_size_list=None): + # load existing data + merged_rob_dict = {} + for fname in os.listdir(self.rob_src_folder): + if ".dict" not in fname: + continue + image_size = int(fname.split(".dict")[0]) + if image_size_list is not None and image_size not in image_size_list: + print("Skip ", fname) + continue + full_path = os.path.join(self.rob_src_folder, fname) + partial_rob_dict = json.load(open(full_path)) + merged_rob_dict.update(partial_rob_dict) + print("loaded %s" % full_path) + json.dump(merged_rob_dict, open(self.rob_dict_path, "w"), indent=4) + return merged_rob_dict + + def build_rob_data_loader( + self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16 + ): + # load data + rob_dict = json.load(open(self.rob_dict_path)) + X_all_rob = [] + Y_all_rob = [] + with tqdm(total=len(rob_dict), desc="Loading data") as t: + for k, v in rob_dict.items(): + dic = json.loads(k) + X_all_rob.append(arch_encoder.arch2feature(dic)) + Y_all_rob.append(v / 100.0) # range: 0 - 1 + t.update() + base_rob = np.mean(Y_all_rob) + # convert to torch tensor + X_all_rob = torch.tensor(X_all_rob, dtype=torch.float) + Y_all_rob = torch.tensor(Y_all_rob) + + # random shuffle + shuffle_idx_rob = torch.randperm(len(X_all_rob)) + X_all_rob = X_all_rob[shuffle_idx_rob] + Y_all_rob = Y_all_rob[shuffle_idx_rob] + # split data + idx_rob = X_all_rob.size(0) // 5 * 4 if n_training_sample is None else n_training_sample + val_idx_rob = X_all_rob.size(0) // 5 * 4 + X_train_rob, Y_train_rob = X_all_rob[:idx_rob], Y_all_rob[:idx_rob] + X_test_rob, Y_test_rob = X_all_rob[val_idx_rob:], Y_all_rob[val_idx_rob:] + print("Train Robustness Size: %d," % len(X_train_rob), "Valid Robustness Size: %d" % len(X_test_rob)) + # build data 
loader + train_dataset_rob = RegDataset(X_train_rob, Y_train_rob) + val_dataset_rob = RegDataset(X_test_rob, Y_test_rob) + + train_loader_rob = torch.utils.data.DataLoader( + train_dataset_rob, + batch_size=batch_size, + shuffle=True, + pin_memory=False, + num_workers=n_workers, + ) + valid_loader_rob = torch.utils.data.DataLoader( + val_dataset_rob, + batch_size=batch_size, + shuffle=False, + pin_memory=False, + num_workers=n_workers, + ) + return train_loader_rob, valid_loader_rob , base_rob + + diff --git a/proard/nas/accuracy_predictor/rob_predictor.py b/proard/nas/accuracy_predictor/rob_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..078dd6695008c22ba4ca38045f3894f3d12d61ca --- /dev/null +++ b/proard/nas/accuracy_predictor/rob_predictor.py @@ -0,0 +1,66 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
import os

import numpy as np
import torch
import torch.nn as nn

__all__ = ["RobustnessPredictor"]


class RobustnessPredictor(nn.Module):
    """MLP that predicts a single robustness score from an architecture encoding.

    Same trunk shape as the joint accuracy/robustness predictor in this
    package, but with a single bias-free output head; predictions are offset
    by the frozen mean ``base_rob`` so the net only learns residuals.
    """

    def __init__(
        self,
        arch_encoder,
        hidden_size=400,
        n_layers=3,
        checkpoint_path=None,
        device="cuda:0",
        base_rob_val=None,
    ):
        super(RobustnessPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device
        self.base_rob_val = base_rob_val

        # Trunk: n_layers x (Linear + ReLU), then a bias-free scalar head.
        layers = []
        for i in range(self.n_layers):
            layers.append(
                nn.Sequential(
                    nn.Linear(
                        self.arch_encoder.n_dim if i == 0 else self.hidden_size,
                        self.hidden_size,
                    ),
                    nn.ReLU(inplace=True),
                )
            )
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)

        # FIX: `is not None` (was `!= None`): identity is the PEP 8 mandated
        # None check and also honors a legitimate base_rob_val of 0.0.
        if self.base_rob_val is not None:
            self.base_rob = nn.Parameter(
                torch.tensor(self.base_rob_val, device=self.device),
                requires_grad=False,
            )
        else:
            self.base_rob = nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=False
            )

        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            self.load_state_dict(checkpoint)
            print("Loaded checkpoint from %s" % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        """Return robustness predictions; shape (N,) after the squeeze."""
        y = self.layers(x).squeeze()
        return y + self.base_rob

    def predict_rob(self, arch_dict_list):
        """Encode each arch dict and predict its robustness."""
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)

# --- diff metadata (end of proard/nas/accuracy_predictor/rob_predictor.py; next
# new file in this diff: proard/nas/efficiency_predictor/__init__.py) ---
# --- diff metadata: @@ -0,0 +1,78 @@ (new file: proard/nas/efficiency_predictor/__init__.py) ---

# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import copy

from .latency_lookup_table import *


class BaseEfficiencyModel:
    """Common plumbing for FLOPs / latency estimators over a dynamic supernet."""

    def __init__(self, dyn_net):
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        """Activate the subnet described by ``arch_dict``; return (config, image_size).

        The dict is deep-copied first so the caller's copy keeps its
        "image_size" entry untouched.
        """
        spec = copy.deepcopy(arch_dict)
        resolution = spec.pop("image_size")
        self.dyn_net.set_active_subnet(**spec)
        return self.dyn_net.get_active_net_config(), resolution

    def get_efficiency(self, arch_dict):
        raise NotImplementedError


class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator for the ProxylessNAS search space."""

    def get_efficiency(self, arch_dict):
        config, resolution = self.get_active_subnet_config(arch_dict)
        return ProxylessNASLatencyTable.count_flops_given_config(config, resolution)


class Mbv3FLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator for the MobileNetV3 search space."""

    def get_efficiency(self, arch_dict):
        config, resolution = self.get_active_subnet_config(arch_dict)
        # NOTE(review): mbv3 arch dicts store "image_size" as a one-element
        # list (see MobileNetArchEncoder.random_sample_arch), hence the [0].
        return MBv3LatencyTable.count_flops_given_config(config, resolution[0])


class ResNet50FLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator for the ResNet50 search space."""

    def get_efficiency(self, arch_dict):
        config, resolution = self.get_active_subnet_config(arch_dict)
        return ResNet50LatencyTable.count_flops_given_config(config, resolution)


class ProxylessNASLatencyModel(BaseEfficiencyModel):
    """Latency estimator backed by one lookup table per input resolution."""

    def __init__(self, dyn_net, lookup_table_path_dict):
        super(ProxylessNASLatencyModel, self).__init__(dyn_net)
        self.latency_tables = {
            resolution: ProxylessNASLatencyTable(
                local_dir="/tmp/.dyn_latency_tools/",
                url=os.path.join(root, "%d_lookup_table.yaml" % resolution),
            )
            for resolution, root in lookup_table_path_dict.items()
        }

    def get_efficiency(self, arch_dict):
        config, resolution = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[resolution]
        return table.predict_network_latency_given_config(config, resolution)


class Mbv3LatencyModel(BaseEfficiencyModel):
    """Latency estimator backed by one lookup table per input resolution (MBv3)."""

    def __init__(self, dyn_net, lookup_table_path_dict):
        super(Mbv3LatencyModel, self).__init__(dyn_net)
        self.latency_tables = {
            resolution: MBv3LatencyTable(
                local_dir="/tmp/.dyn_latency_tools/",
                url=os.path.join(root, "%d_lookup_table.yaml" % resolution),
            )
            for resolution, root in lookup_table_path_dict.items()
        }

    def get_efficiency(self, arch_dict):
        config, resolution = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[resolution]
        return table.predict_network_latency_given_config(config, resolution)

# --- diff metadata (next new file in this diff:
# proard/nas/efficiency_predictor/latency_lookup_table.py, @@ -0,0 +1,567 @@,
# followed by its ICLR'20 copyright header) ---
+ +import yaml +from proard.utils import download_url, make_divisible, MyNetwork + +__all__ = [ + "count_conv_flop", + "ProxylessNASLatencyTable", + "MBv3LatencyTable", + "ResNet50LatencyTable", +] + + +def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups): + out_h = out_w = out_size + delta_ops = ( + in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups + ) + return delta_ops + + +class LatencyTable(object): + def __init__( + self, + local_dir="~/.dyn/latency_tools/", + url="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/mobile_trim.yaml", + ): + if url.startswith("http"): + fname = download_url(url, local_dir, overwrite=True) + else: + fname = url + with open(fname, "r") as fp: + self.lut = yaml.load(fp) + + @staticmethod + def repr_shape(shape): + if isinstance(shape, (list, tuple)): + return "x".join(str(_) for _ in shape) + elif isinstance(shape, str): + return shape + else: + return TypeError + + def query(self, **kwargs): + raise NotImplementedError + + def predict_network_latency(self, net, image_size): + raise NotImplementedError + + def predict_network_latency_given_config(self, net_config, image_size): + raise NotImplementedError + + @staticmethod + def count_flops_given_config(net_config, image_size=224): + raise NotImplementedError + + +class ProxylessNASLatencyTable(LatencyTable): + def query( + self, + l_type: str, + input_shape, + output_shape, + expand=None, + ks=None, + stride=None, + id_skip=None, + ): + """ + :param l_type: + Layer type must be one of the followings + 1. `Conv`: The initial 3x3 conv with stride 2. + 2. `Conv_1`: feature_mix_layer + 3. `Logits`: All operations after `Conv_1`. + 4. 
`expanded_conv`: MobileInvertedResidual + :param input_shape: input shape (h, w, #channels) + :param output_shape: output shape (h, w, #channels) + :param expand: expansion ratio + :param ks: kernel size + :param stride: + :param id_skip: indicate whether has the residual connection + """ + infos = [ + l_type, + "input:%s" % self.repr_shape(input_shape), + "output:%s" % self.repr_shape(output_shape), + ] + + if l_type in ("expanded_conv",): + assert None not in (expand, ks, stride, id_skip) + infos += [ + "expand:%d" % expand, + "kernel:%d" % ks, + "stride:%d" % stride, + "idskip:%d" % id_skip, + ] + key = "-".join(infos) + return self.lut[key]["mean"] + + def predict_network_latency(self, net, image_size=224): + predicted_latency = 0 + # first conv + predicted_latency += self.query( + "Conv", + [image_size, image_size, 3], + [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels], + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net.blocks: + mb_conv = block.conv + shortcut = block.shortcut + + if mb_conv is None: + continue + if shortcut is None: + idskip = 0 + else: + idskip = 1 + out_fz = int((fsize - 1) / mb_conv.stride + 1) # fsize // mb_conv.stride + block_latency = self.query( + "expanded_conv", + [fsize, fsize, mb_conv.in_channels], + [out_fz, out_fz, mb_conv.out_channels], + expand=mb_conv.expand_ratio, + ks=mb_conv.kernel_size, + stride=mb_conv.stride, + id_skip=idskip, + ) + predicted_latency += block_latency + fsize = out_fz + # feature mix layer + predicted_latency += self.query( + "Conv_1", + [fsize, fsize, net.feature_mix_layer.in_channels], + [fsize, fsize, net.feature_mix_layer.out_channels], + ) + # classifier + predicted_latency += self.query( + "Logits", + [fsize, fsize, net.classifier.in_features], + [net.classifier.out_features], # 1000 + ) + return predicted_latency + + def predict_network_latency_given_config(self, net_config, image_size=224): + predicted_latency = 0 + # first conv + predicted_latency += 
self.query( + "Conv", + [image_size, image_size, 3], + [ + (image_size + 1) // 2, + (image_size + 1) // 2, + net_config["first_conv"]["out_channels"], + ], + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net_config["blocks"]: + mb_conv = ( + block["mobile_inverted_conv"] + if "mobile_inverted_conv" in block + else block["conv"] + ) + shortcut = block["shortcut"] + + if mb_conv is None: + continue + if shortcut is None: + idskip = 0 + else: + idskip = 1 + out_fz = int((fsize - 1) / mb_conv["stride"] + 1) + block_latency = self.query( + "expanded_conv", + [fsize, fsize, mb_conv["in_channels"]], + [out_fz, out_fz, mb_conv["out_channels"]], + expand=mb_conv["expand_ratio"], + ks=mb_conv["kernel_size"], + stride=mb_conv["stride"], + id_skip=idskip, + ) + predicted_latency += block_latency + fsize = out_fz + # feature mix layer + predicted_latency += self.query( + "Conv_1", + [fsize, fsize, net_config["feature_mix_layer"]["in_channels"]], + [fsize, fsize, net_config["feature_mix_layer"]["out_channels"]], + ) + # classifier + predicted_latency += self.query( + "Logits", + [fsize, fsize, net_config["classifier"]["in_features"]], + [net_config["classifier"]["out_features"]], # 1000 + ) + return predicted_latency + + @staticmethod + def count_flops_given_config(net_config, image_size=224): + flops = 0 + # first conv + flops += count_conv_flop( + (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1 + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net_config["blocks"]: + mb_conv = ( + block["mobile_inverted_conv"] + if "mobile_inverted_conv" in block + else block["conv"] + ) + if mb_conv is None: + continue + out_fz = int((fsize - 1) / mb_conv["stride"] + 1) + if mb_conv["mid_channels"] is None: + mb_conv["mid_channels"] = round( + mb_conv["in_channels"] * mb_conv["expand_ratio"] + ) + if mb_conv["expand_ratio"] != 1: + # inverted bottleneck + flops += count_conv_flop( + fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1 
+ ) + # depth conv + flops += count_conv_flop( + out_fz, + mb_conv["mid_channels"], + mb_conv["mid_channels"], + mb_conv["kernel_size"], + mb_conv["mid_channels"], + ) + # point linear + flops += count_conv_flop( + out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1 + ) + fsize = out_fz + # feature mix layer + flops += count_conv_flop( + fsize, + net_config["feature_mix_layer"]["in_channels"], + net_config["feature_mix_layer"]["out_channels"], + 1, + 1, + ) + # classifier + flops += count_conv_flop( + 1, + net_config["classifier"]["in_features"], + net_config["classifier"]["out_features"], + 1, + 1, + ) + return flops / 1e6 # MFLOPs + + +class MBv3LatencyTable(LatencyTable): + def query( + self, + l_type: str, + input_shape, + output_shape, + mid=None, + ks=None, + stride=None, + id_skip=None, + se=None, + h_swish=None, + ): + infos = [ + l_type, + "input:%s" % self.repr_shape(input_shape), + "output:%s" % self.repr_shape(output_shape), + ] + + if l_type in ("expanded_conv",): + assert None not in (mid, ks, stride, id_skip, se, h_swish) + infos += [ + "expand:%d" % mid, + "kernel:%d" % ks, + "stride:%d" % stride, + "idskip:%d" % id_skip, + "se:%d" % se, + "hs:%d" % h_swish, + ] + key = "-".join(infos) + return self.lut[key]["mean"] + + def predict_network_latency(self, net, image_size=224): + predicted_latency = 0 + # first conv + predicted_latency += self.query( + "Conv", + [image_size, image_size, 3], + [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels], + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net.blocks: + mb_conv = block.conv + shortcut = block.shortcut + + if mb_conv is None: + continue + if shortcut is None: + idskip = 0 + else: + idskip = 1 + out_fz = int((fsize - 1) / mb_conv.stride + 1) + block_latency = self.query( + "expanded_conv", + [fsize, fsize, mb_conv.in_channels], + [out_fz, out_fz, mb_conv.out_channels], + mid=mb_conv.depth_conv.conv.in_channels, + ks=mb_conv.kernel_size, + 
stride=mb_conv.stride, + id_skip=idskip, + se=1 if mb_conv.use_se else 0, + h_swish=1 if mb_conv.act_func == "h_swish" else 0, + ) + predicted_latency += block_latency + fsize = out_fz + # final expand layer + predicted_latency += self.query( + "Conv_1", + [fsize, fsize, net.final_expand_layer.in_channels], + [fsize, fsize, net.final_expand_layer.out_channels], + ) + # global average pooling + predicted_latency += self.query( + "AvgPool2D", + [fsize, fsize, net.final_expand_layer.out_channels], + [1, 1, net.final_expand_layer.out_channels], + ) + # feature mix layer + predicted_latency += self.query( + "Conv_2", + [1, 1, net.feature_mix_layer.in_channels], + [1, 1, net.feature_mix_layer.out_channels], + ) + # classifier + predicted_latency += self.query( + "Logits", [1, 1, net.classifier.in_features], [net.classifier.out_features] + ) + return predicted_latency + + def predict_network_latency_given_config(self, net_config, image_size=224): + predicted_latency = 0 + # first conv + predicted_latency += self.query( + "Conv", + [image_size, image_size, 3], + [ + (image_size + 1) // 2, + (image_size + 1) // 2, + net_config["first_conv"]["out_channels"], + ], + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net_config["blocks"]: + mb_conv = ( + block["mobile_inverted_conv"] + if "mobile_inverted_conv" in block + else block["conv"] + ) + shortcut = block["shortcut"] + + if mb_conv is None: + continue + if shortcut is None: + idskip = 0 + else: + idskip = 1 + out_fz = int((fsize - 1) / mb_conv["stride"] + 1) + if mb_conv["mid_channels"] is None: + mb_conv["mid_channels"] = round( + mb_conv["in_channels"] * mb_conv["expand_ratio"] + ) + block_latency = self.query( + "expanded_conv", + [fsize, fsize, mb_conv["in_channels"]], + [out_fz, out_fz, mb_conv["out_channels"]], + mid=mb_conv["mid_channels"], + ks=mb_conv["kernel_size"], + stride=mb_conv["stride"], + id_skip=idskip, + se=1 if mb_conv["use_se"] else 0, + h_swish=1 if mb_conv["act_func"] == "h_swish" else 
0, + ) + predicted_latency += block_latency + fsize = out_fz + # final expand layer + predicted_latency += self.query( + "Conv_1", + [fsize, fsize, net_config["final_expand_layer"]["in_channels"]], + [fsize, fsize, net_config["final_expand_layer"]["out_channels"]], + ) + # global average pooling + predicted_latency += self.query( + "AvgPool2D", + [fsize, fsize, net_config["final_expand_layer"]["out_channels"]], + [1, 1, net_config["final_expand_layer"]["out_channels"]], + ) + # feature mix layer + predicted_latency += self.query( + "Conv_2", + [1, 1, net_config["feature_mix_layer"]["in_channels"]], + [1, 1, net_config["feature_mix_layer"]["out_channels"]], + ) + # classifier + predicted_latency += self.query( + "Logits", + [1, 1, net_config["classifier"]["in_features"]], + [net_config["classifier"]["out_features"]], + ) + return predicted_latency + + @staticmethod + def count_flops_given_config(net_config, image_size=224): + flops = 0 + # first conv + flops += count_conv_flop( + (image_size + 1) // 2, 3, net_config["first_conv"]["out_channels"], 3, 1 + ) + # blocks + fsize = (image_size + 1) // 2 + for block in net_config["blocks"]: + mb_conv = ( + block["mobile_inverted_conv"] + if "mobile_inverted_conv" in block + else block["conv"] + ) + if mb_conv is None: + continue + out_fz = int((fsize - 1) / mb_conv["stride"] + 1) + if mb_conv["mid_channels"] is None: + mb_conv["mid_channels"] = round( + mb_conv["in_channels"] * mb_conv["expand_ratio"] + ) + if mb_conv["expand_ratio"] != 1: + # inverted bottleneck + flops += count_conv_flop( + fsize, mb_conv["in_channels"], mb_conv["mid_channels"], 1, 1 + ) + # depth conv + flops += count_conv_flop( + out_fz, + mb_conv["mid_channels"], + mb_conv["mid_channels"], + mb_conv["kernel_size"], + mb_conv["mid_channels"], + ) + if mb_conv["use_se"]: + # SE layer + se_mid = make_divisible( + mb_conv["mid_channels"] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE + ) + flops += count_conv_flop(1, mb_conv["mid_channels"], se_mid, 1, 1) + 
flops += count_conv_flop(1, se_mid, mb_conv["mid_channels"], 1, 1) + # point linear + flops += count_conv_flop( + out_fz, mb_conv["mid_channels"], mb_conv["out_channels"], 1, 1 + ) + fsize = out_fz + # final expand layer + flops += count_conv_flop( + fsize, + net_config["final_expand_layer"]["in_channels"], + net_config["final_expand_layer"]["out_channels"], + 1, + 1, + ) + # feature mix layer + flops += count_conv_flop( + 1, + net_config["feature_mix_layer"]["in_channels"], + net_config["feature_mix_layer"]["out_channels"], + 1, + 1, + ) + # classifier + flops += count_conv_flop( + 1, + net_config["classifier"]["in_features"], + net_config["classifier"]["out_features"], + 1, + 1, + ) + return flops / 1e6 # MFLOPs + + +class ResNet50LatencyTable(LatencyTable): + def query(self, **kwargs): + raise NotImplementedError + + def predict_network_latency(self, net, image_size): + raise NotImplementedError + + def predict_network_latency_given_config(self, net_config, image_size): + raise NotImplementedError + + @staticmethod + def count_flops_given_config(net_config, image_size=32): + flops = 0 + # input stem + for layer_config in net_config["input_stem"]: + if layer_config["name"] != "ConvLayer": + layer_config = layer_config["conv"] + in_channel = layer_config["in_channels"] + out_channel = layer_config["out_channels"] + out_image_size = int((image_size - 1) / layer_config["stride"] + 1) + + flops += count_conv_flop( + out_image_size, + in_channel, + out_channel, + layer_config["kernel_size"], + layer_config.get("groups", 1), + ) + image_size = out_image_size + # max pooling + # image_size = int((image_size - 1) / 2 + 1) + # ResNetBottleneckBlocks + for block_config in net_config["blocks"]: + in_channel = block_config["in_channels"] + out_channel = block_config["out_channels"] + + out_image_size = int((image_size - 1) / block_config["stride"] + 1) + mid_channel = ( + block_config["mid_channels"] + if block_config["mid_channels"] is not None + else round(out_channel * 
block_config["expand_ratio"]) + ) + mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE) + + # conv1 + flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1) + # conv2 + flops += count_conv_flop( + out_image_size, + mid_channel, + mid_channel, + block_config["kernel_size"], + block_config["groups"], + ) + # conv3 + flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1) + # downsample + if block_config["stride"] == 1 and in_channel == out_channel: + pass + else: + flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1) + image_size = out_image_size + # final classifier + flops += count_conv_flop( + 1, + net_config["classifier"]["in_features"], + net_config["classifier"]["out_features"], + 1, + 1, + ) + return flops / 1e6 # MFLOPs diff --git a/proard/nas/search_algorithm/__init__.py b/proard/nas/search_algorithm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8fd6b36e3634812755e87973a5b03a21214f --- /dev/null +++ b/proard/nas/search_algorithm/__init__.py @@ -0,0 +1,6 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +from .evolution import * +from .multi_evolution import * diff --git a/proard/nas/search_algorithm/evolution.py b/proard/nas/search_algorithm/evolution.py new file mode 100644 index 0000000000000000000000000000000000000000..0371e97b3e71e890c4ce28b4a67940d17dab97d3 --- /dev/null +++ b/proard/nas/search_algorithm/evolution.py @@ -0,0 +1,143 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
import copy
import random
import numpy as np
from tqdm import tqdm

__all__ = ["EvolutionFinder"]


class EvolutionFinder:
    """Regularized-evolution search over an architecture space.

    Candidates are kept only if their predicted efficiency (latency/FLOPs,
    whatever ``efficiency_predictor`` returns) satisfies a hard ``constraint``;
    fitness for parent selection is the predicted accuracy.

    NOTE(review): robustness is predicted and reported, but selection sorts by
    accuracy alone (see ``run_evolution_search``) — confirm that is intended.
    """

    def __init__(self, efficiency_predictor, accuracy_predictor, Robustness_predictor, **kwargs):
        # Predictors are opaque here; they must expose get_efficiency /
        # predict_acc / predict_rob, and accuracy_predictor.arch_encoder.
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor
        self.robustness_predictor = Robustness_predictor

        # evolution hyper-parameters
        self.arch_mutate_prob = kwargs.get("arch_mutate_prob", 0.1)
        self.resolution_mutate_prob = kwargs.get("resolution_mutate_prob", 0.5)
        self.population_size = kwargs.get("population_size", 100)
        self.max_time_budget = kwargs.get("max_time_budget", 500)
        self.parent_ratio = kwargs.get("parent_ratio", 0.25)
        self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)

    @property
    def arch_manager(self):
        # The encoder doubles as the sampler/mutator of architectures.
        return self.accuracy_predictor.arch_encoder

    def update_hyper_params(self, new_param_dict):
        # Blind attribute update: callers may override any hyper-parameter above.
        self.__dict__.update(new_param_dict)

    def random_valid_sample(self, constraint):
        """Rejection-sample a random architecture until it meets `constraint`.

        May loop forever if the constraint is infeasible for the space.
        """
        while True:
            sample = self.arch_manager.random_sample_arch()
            efficiency = self.efficiency_predictor.get_efficiency(sample)
            if efficiency <= constraint:
                return sample, efficiency

    def mutate_sample(self, sample, constraint):
        """Mutate resolution and architecture, re-sampling until feasible."""
        while True:
            new_sample = copy.deepcopy(sample)

            self.arch_manager.mutate_resolution(new_sample, self.resolution_mutate_prob)
            self.arch_manager.mutate_arch(new_sample, self.arch_mutate_prob)

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def crossover_sample(self, sample1, sample2, constraint):
        """Uniform crossover: each gene (or list element) picked from either parent."""
        while True:
            new_sample = copy.deepcopy(sample1)
            for key in new_sample.keys():
                if not isinstance(new_sample[key], list):
                    new_sample[key] = random.choice([sample1[key], sample2[key]])
                else:
                    for i in range(len(new_sample[key])):
                        new_sample[key][i] = random.choice(
                            [sample1[key][i], sample2[key][i]]
                        )

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def run_evolution_search(self, constraint, verbose=False, **kwargs):
        """Run a single roll-out of regularized evolution to a fixed time budget."""
        self.update_hyper_params(kwargs)

        mutation_numbers = int(round(self.mutation_ratio * self.population_size))
        parents_size = int(round(self.parent_ratio * self.population_size))

        # -100 is a sentinel lower than any real accuracy, so the first
        # generation always improves on it.
        best_valids = [-100]
        population = []  # (validation, robustness, sample, latency) tuples
        child_pool = []
        efficiency_pool = []
        best_info = None
        if verbose:
            print("Generate random population...")
        for _ in range(self.population_size):
            sample, efficiency = self.random_valid_sample(constraint)
            child_pool.append(sample)
            efficiency_pool.append(efficiency)

        # Predictors score the whole pool in one batched call.
        accs = self.accuracy_predictor.predict_acc(child_pool)
        robs = self.robustness_predictor.predict_rob(child_pool)
        for i in range(self.population_size):
            population.append((accs[i].item(), robs[i].item(), child_pool[i], efficiency_pool[i]))

        if verbose:
            print("Start Evolution...")
        # After the population is seeded, proceed with evolving the population.
        with tqdm(
            total=self.max_time_budget,
            desc="Searching with constraint (%s)" % constraint,
            disable=(not verbose),
        ) as t:
            for i in range(self.max_time_budget):
                # Parents = top-k by accuracy (tuple slot 0), descending.
                parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
                acc = parents[0][0]
                rob = parents[0][1]  # NOTE(review): assigned but only displayed via parents[0][1] below
                t.set_postfix({"acc": parents[0][0] , "rob":parents[0][1]})
                if not verbose and (i + 1) % 100 == 0:
                    print("Iter: {} Acc: {} Rob: {}".format(i + 1, parents[0][0],parents[0][1]))

                if acc > best_valids[-1]:
                    best_valids.append(acc)
                    best_info = parents[0]
                else:
                    best_valids.append(best_valids[-1])

                # Elitism: parents survive; children are appended below, so
                # indices < parents_size always address a parent.
                population = parents
                child_pool = []
                efficiency_pool = []

                for j in range(mutation_numbers):
                    par_sample = population[np.random.randint(parents_size)][2]
                    # Mutate
                    new_sample, efficiency = self.mutate_sample(par_sample, constraint)
                    child_pool.append(new_sample)
                    efficiency_pool.append(efficiency)

                for j in range(self.population_size - mutation_numbers):
                    par_sample1 = population[np.random.randint(parents_size)][2]
                    par_sample2 = population[np.random.randint(parents_size)][2]
                    # Crossover
                    new_sample, efficiency = self.crossover_sample(
                        par_sample1, par_sample2, constraint
                    )
                    child_pool.append(new_sample)
                    efficiency_pool.append(efficiency)

                accs = self.accuracy_predictor.predict_acc(child_pool)
                robs = self.robustness_predictor.predict_rob(child_pool)
                for j in range(self.population_size):
                    population.append(
                        (accs[j].item(), robs[j].item(), child_pool[j], efficiency_pool[j])
                    )

                t.update(1)

        return best_valids, best_info
diff --git a/proard/nas/search_algorithm/multi_evolution.py b/proard/nas/search_algorithm/multi_evolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..157599892b79b5a964fd9ec045c38faf9e036f68
--- /dev/null
+++ b/proard/nas/search_algorithm/multi_evolution.py
@@ -0,0 +1,143 @@
import numpy as np
from pymoo.core.individual import Individual
from pymoo.core.problem import Problem
from
pymoo.core.sampling import Sampling +from pymoo.core.variable import Choice +__all__ = ["individual_to_arch_mbv","DynIndividual_mbv","DynProblem_mbv","individual_to_arch_res","DynIndividual_res","DynProblem_res","DynSampling","DynRandomSampler"] +def individual_to_arch_mbv(population, n_blocks): + archs = [] + for individual in population: + archs.append( + { + "ks": individual[0:n_blocks], + "e": individual[n_blocks : 2 * n_blocks], + "d": individual[2 * n_blocks : -1], + "image_size": individual[-1:], + } + ) + return archs +class DynIndividual_mbv(Individual): + def __init__(self, individual, accuracy_predictor,Robustness_predictor, config=None, **kwargs): + super().__init__(config=None, **kwargs) + self.X = np.concatenate( + ( + individual[0]["ks"], + individual[0]["e"], + individual[0]["d"], + individual[0]["image_size"], + ) + ) + self.flops = individual[1] + self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]]) + self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]]) + self.F = np.concatenate(([self.flops], [self.accuracy.squeeze().cpu().detach().numpy()],[self.robustness.squeeze().cpu().detach().numpy()])) + + + +class DynProblem_mbv(Problem): + def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars): + self.ks = Choice(options=search_vars.get('ks')) + self.e = Choice(options=search_vars.get('e')) + self.d = Choice(options=search_vars.get('d')) + self.r = Choice(options=search_vars.get('image_size')) + + super().__init__( + vars= dict(zip(range(len(num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), + n_obj=3, + n_constr=0, + ) + self.efficiency_predictor = efficiency_predictor + self.accuracy_predictor = accuracy_predictor + self.robustness_predictor = robustness_predictor + self.blocks = num_blocks + self.stages = num_stages + 
self.search_vars = search_vars + + def _evaluate(self, x, out, *args, **kwargs): + f1=[] + # x.shape = (population_size, n_var) = (100, 4) + arch = individual_to_arch_mbv(x, self.blocks) + for arc in arch: + f1.append(self.efficiency_predictor.get_efficiency(arc)) + f2 = 100 - self.accuracy_predictor.predict_acc(arch).detach().cpu().numpy() + f3 = 100 - self.robustness_predictor.predict_rob(arch).detach().cpu().numpy() + out["F"] = np.column_stack([f1, f2,f3]) + + +def individual_to_arch_res(population, n_blocks): + archs = [] + for individual in population: + archs.append( + { + "e": individual[n_blocks : 2 * n_blocks], + "d": individual[2 * n_blocks : -1], + "w": individual[0:n_blocks], + "r": individual[-1:], + } + ) + return archs +class DynIndividual_res(Individual): + def __init__(self, individual, accuracy_predictor,Robustness_predictor, config=None, **kwargs): + super().__init__(config=None, **kwargs) + self.X = np.concatenate( + ( + individual[0]["e"], + individual[0]["d"], + individual[0]["w"], + [individual[0]["image_size"]], + ) + ) + self.flops = individual[1] + self.accuracy = 100 - accuracy_predictor.predict_acc([individual[0]]) + self.robustness = 100 - Robustness_predictor.predict_rob([individual[0]]) + self.F = np.concatenate(([self.flops], [self.accuracy.squeeze().cpu().detach().numpy()],[self.robustness.squeeze().cpu().detach().numpy()])) + + + +class DynProblem_res(Problem): + def __init__(self, efficiency_predictor, accuracy_predictor, robustness_predictor, num_blocks, num_stages, search_vars): + self.e = Choice(options=search_vars.get('e')) + self.d = Choice(options=search_vars.get('d')) + self.w = Choice(options=search_vars.get('w')) + self.r = Choice(options=search_vars.get('image_size')) + super().__init__( + vars= dict(zip(range(len(num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), num_blocks * [self.ks] + num_blocks * [self.e] + num_stages * [self.d] + [self.r])), + n_obj=3, + n_constr=0, + ) + 
self.efficiency_predictor = efficiency_predictor + self.accuracy_predictor = accuracy_predictor + self.robustness_predictor = robustness_predictor + self.blocks = num_blocks + self.stages = num_stages + self.search_vars = search_vars + + def _evaluate(self, x, out, *args, **kwargs): + f1={} + # x.shape = (population_size, n_var) = (100, 4) + arch = individual_to_arch_res(x, self.blocks) + for arc in arch: + f1.append(self.efficiency_predictor.get_efficiency(arc)) + f2 = 100 - self.accuracy_predictor.predict_acc(arch) + f3 = 100 - self.robustness_predictor.predict_rob(arch) + out["F"] = np.column_stack([f1, f2,f3]) + + + +class DynSampling(Sampling): + def _do(self, problem, n_samples, **kwargs): + return [ + [np.random.choice(var.options) for key,var in problem.vars.items()] + for _ in range(n_samples) + ] + + +class DynRandomSampler: + def __init__(self, arch_manager, efficiency_predictor): + self.arch_manager = arch_manager + self.efficiency_predictor = efficiency_predictor + + def random_sample(self): + sample = self.arch_manager.random_sample_arch() + efficiency = self.efficiency_predictor.get_efficiency(sample) + return sample, efficiency \ No newline at end of file diff --git a/proard/utils/__init__.py b/proard/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1839557ed721041b0e6844ea8b4c1c4b42a69bb8 --- /dev/null +++ b/proard/utils/__init__.py @@ -0,0 +1,10 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +from .pytorch_modules import * +from .pytorch_utils import * +from .my_modules import * +from .flops_counter import * +from .common_tools import * +from .my_dataloader import * diff --git a/proard/utils/common_tools.py b/proard/utils/common_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..01a3899d8381e95a93d54455500db0b5f06bac33 --- /dev/null +++ b/proard/utils/common_tools.py @@ -0,0 +1,307 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import numpy as np +import os +import sys +import torch + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + +__all__ = [ + "sort_dict", + "get_same_padding", + "get_split_list", + "list_sum", + "list_mean", + "list_join", + "subset_mean", + "sub_filter_start_end", + "min_divisible_value", + "val2list", + "download_url", + "write_log", + "pairwise_accuracy", + "accuracy", + "AverageMeter", + "MultiClassAverageMeter", + "DistributedMetric", + "DistributedTensor", +] + + +def sort_dict(src_dict, reverse=False, return_dict=True): + output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse) + if return_dict: + return dict(output) + else: + return output + + +def get_same_padding(kernel_size): + if isinstance(kernel_size, tuple): + assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size + p1 = get_same_padding(kernel_size[0]) + p2 = get_same_padding(kernel_size[1]) + return p1, p2 + assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`" + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def get_split_list(in_dim, child_num, accumulate=False): + in_dim_list = [in_dim // child_num] * child_num + for _i in range(in_dim % child_num): + in_dim_list[_i] += 1 + if accumulate: + for i in range(1, 
child_num): + in_dim_list[i] += in_dim_list[i - 1] + return in_dim_list + + +def list_sum(x): + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def list_mean(x): + return list_sum(x) / len(x) + + +def list_join(val_list, sep="\t"): + return sep.join([str(val) for val in val_list]) + + +def subset_mean(val_list, sub_indexes): + sub_indexes = val2list(sub_indexes, 1) + return list_mean([val_list[idx] for idx in sub_indexes]) + + +def sub_filter_start_end(kernel_size, sub_kernel_size): + center = kernel_size // 2 + dev = sub_kernel_size // 2 + start, end = center - dev, center + dev + 1 + assert end - start == sub_kernel_size + return start, end + + +def min_divisible_value(n1, v1): + """make sure v1 is divisible by n1, otherwise decrease v1""" + if v1 >= n1: + return n1 + while n1 % v1 != 0: + v1 -= 1 + return v1 + + +def val2list(val, repeat_time=1): + if isinstance(val, list) or isinstance(val, np.ndarray): + return val + elif isinstance(val, tuple): + return list(val) + else: + return [val for _ in range(repeat_time)] + + +def download_url(url, model_dir="~/.torch/", overwrite=False): + target_dir = url.split("/")[-1] + model_dir = os.path.expanduser(model_dir) + try: + if not os.path.exists(model_dir): + os.makedirs(model_dir) + model_dir = os.path.join(model_dir, target_dir) + cached_file = model_dir + if not os.path.exists(cached_file) or overwrite: + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return cached_file + except Exception as e: + # remove lock file so download can be executed next time. 
+ os.remove(os.path.join(model_dir, "download.lock")) + sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n") + return None + + +def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"): + if not os.path.exists(logs_path): + os.makedirs(logs_path, exist_ok=True) + """ prefix: valid, train, test """ + if prefix in ["valid", "test"]: + with open(os.path.join(logs_path, "valid_console.txt"), mode) as fout: + fout.write(log_str + "\n") + fout.flush() + if prefix in ["valid", "test", "train"]: + with open(os.path.join(logs_path, "train_console.txt"), mode) as fout: + if prefix in ["valid", "test"]: + fout.write("=" * 10) + fout.write(log_str + "\n") + fout.flush() + else: + with open(os.path.join(logs_path, "%s.txt" % prefix), mode) as fout: + fout.write(log_str + "\n") + fout.flush() + if should_print: + print(log_str) + + +def pairwise_accuracy(la, lb, n_samples=200000): + n = len(la) + assert n == len(lb) + total = 0 + count = 0 + for _ in range(n_samples): + i = np.random.randint(n) + j = np.random.randint(n) + while i == j: + j = np.random.randint(n) + if la[i] >= la[j] and lb[i] >= lb[j]: + count += 1 + if la[i] < la[j] and lb[i] < lb[j]: + count += 1 + total += 1 + return float(count) / total + + + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class AverageMeter(object): + """ + Computes and stores the average and current value + Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py + """ + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + 
self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class MultiClassAverageMeter: + + """Multi Binary Classification Tasks""" + + def __init__(self, num_classes, balanced=False, **kwargs): + + super(MultiClassAverageMeter, self).__init__() + self.num_classes = num_classes + self.balanced = balanced + + self.counts = [] + for k in range(self.num_classes): + self.counts.append(np.ndarray((2, 2), dtype=np.float32)) + + self.reset() + + def reset(self): + for k in range(self.num_classes): + self.counts[k].fill(0) + + def add(self, outputs, targets): + outputs = outputs.data.cpu().numpy() + targets = targets.data.cpu().numpy() + + for k in range(self.num_classes): + output = np.argmax(outputs[:, k, :], axis=1) + target = targets[:, k] + + x = output + 2 * target + bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2) + + self.counts[k] += bincount.reshape((2, 2)) + + def value(self): + mean = 0 + for k in range(self.num_classes): + if self.balanced: + value = np.mean( + ( + self.counts[k] + / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None] + ).diagonal() + ) + else: + value = np.sum(self.counts[k].diagonal()) / np.maximum( + np.sum(self.counts[k]), 1 + ) + + mean += value / self.num_classes * 100.0 + return mean + + +class DistributedMetric(object): + """ + Horovod: average metrics from distributed training. 
+ """ + + def __init__(self, name): + self.name = name + self.sum = torch.zeros(1)[0] + self.count = torch.zeros(1)[0] + + def update(self, val, delta_n=1): + import horovod.torch as hvd + + val *= delta_n + self.sum += hvd.allreduce(val.detach().cpu(), name=self.name) + self.count += delta_n + + @property + def avg(self): + return self.sum / self.count + + +class DistributedTensor(object): + def __init__(self, name): + self.name = name + self.sum = None + self.count = torch.zeros(1)[0] + self.synced = False + + def update(self, val, delta_n=1): + val *= delta_n + if self.sum is None: + self.sum = val.detach() + else: + self.sum += val.detach() + self.count += delta_n + + @property + def avg(self): + import horovod.torch as hvd + + if not self.synced: + self.sum = hvd.allreduce(self.sum, name=self.name) + self.synced = True + return self.sum / self.count diff --git a/proard/utils/flops_counter.py b/proard/utils/flops_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..b236a525d6380de249f893be0478e38f9c235ac4 --- /dev/null +++ b/proard/utils/flops_counter.py @@ -0,0 +1,97 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. 
+ +import torch +import torch.nn as nn + +from .my_modules import MyConv2d + +__all__ = ["profile"] + + +def count_convNd(m, _, y): + cin = m.in_channels + + kernel_ops = m.weight.size()[2] * m.weight.size()[3] + ops_per_element = kernel_ops + output_elements = y.nelement() + + # cout x oW x oH + total_ops = cin * output_elements * ops_per_element // m.groups + m.total_ops = torch.zeros(1).fill_(total_ops) + + +def count_linear(m, _, __): + total_ops = m.in_features * m.out_features + + m.total_ops = torch.zeros(1).fill_(total_ops) + + +register_hooks = { + nn.Conv1d: count_convNd, + nn.Conv2d: count_convNd, + nn.Conv3d: count_convNd, + MyConv2d: count_convNd, + ###################################### + nn.Linear: count_linear, + ###################################### + nn.Dropout: None, + nn.Dropout2d: None, + nn.Dropout3d: None, + nn.BatchNorm2d: None, +} + + +def profile(model, input_size, custom_ops=None): + handler_collection = [] + custom_ops = {} if custom_ops is None else custom_ops + + def add_hooks(m_): + if len(list(m_.children())) > 0: + return + + m_.register_buffer("total_ops", torch.zeros(1)) + m_.register_buffer("total_params", torch.zeros(1)) + + for p in m_.parameters(): + m_.total_params += torch.zeros(1).fill_(p.numel()) + + m_type = type(m_) + fn = None + + if m_type in custom_ops: + fn = custom_ops[m_type] + elif m_type in register_hooks: + fn = register_hooks[m_type] + + if fn is not None: + _handler = m_.register_forward_hook(fn) + handler_collection.append(_handler) + + original_device = model.parameters().__next__().device + training = model.training + + model.eval() + model.apply(add_hooks) + + x = torch.zeros(input_size).to(original_device) + with torch.no_grad(): + model(x) + + total_ops = 0 + total_params = 0 + for m in model.modules(): + if len(list(m.children())) > 0: # skip for non-leaf module + continue + total_ops += m.total_ops + total_params += m.total_params + + total_ops = total_ops.item() + total_params = total_params.item() + 
+ model.train(training).to(original_device) + for handler in handler_collection: + handler.remove() + + return total_ops, total_params diff --git a/proard/utils/layers.py b/proard/utils/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9e05d7337760fc5bf937262067a7aa18dfba7379 --- /dev/null +++ b/proard/utils/layers.py @@ -0,0 +1,819 @@ +# Once for All: Train One Network and Specialize it for Efficient Deployment +# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han +# International Conference on Learning Representations (ICLR), 2020. + +import torch +import torch.nn as nn + +from collections import OrderedDict +from proard.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer +from proard.utils import MyNetwork, MyModule +from proard.utils import build_activation, make_divisible + +__all__ = [ + "set_layer_from_config", + "ConvLayer", + "IdentityLayer", + "LinearLayer", + "MultiHeadLinearLayer", + "ZeroLayer", + "MBConvLayer", + "ResidualBlock", + "ResNetBottleneckBlock", +] + + +def set_layer_from_config(layer_config): + if layer_config is None: + return None + + name2layer = { + ConvLayer.__name__: ConvLayer, + IdentityLayer.__name__: IdentityLayer, + LinearLayer.__name__: LinearLayer, + MultiHeadLinearLayer.__name__: MultiHeadLinearLayer, + ZeroLayer.__name__: ZeroLayer, + MBConvLayer.__name__: MBConvLayer, + "MBInvertedConvLayer": MBConvLayer, + ########################################################## + ResidualBlock.__name__: ResidualBlock, + ResNetBottleneckBlock.__name__: ResNetBottleneckBlock, + } + + layer_name = layer_config.pop("name") + layer = name2layer[layer_name] + return layer.build_from_config(layer_config) + + +class My2DLayer(MyModule): + def __init__( + self, + in_channels, + out_channels, + use_bn=True, + act_func="relu", + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(My2DLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + 
self.use_bn = use_bn + self.act_func = act_func + self.dropout_rate = dropout_rate + self.ops_order = ops_order + + """ modules """ + modules = {} + # batch norm + if self.use_bn: + if self.bn_before_weight: + modules["bn"] = nn.BatchNorm2d(in_channels) + else: + modules["bn"] = nn.BatchNorm2d(out_channels) + else: + modules["bn"] = None + # activation + modules["act"] = build_activation( + self.act_func, self.ops_list[0] != "act" and self.use_bn + ) + # dropout + if self.dropout_rate > 0: + modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True) + else: + modules["dropout"] = None + # weight + modules["weight"] = self.weight_op() + + # add modules + for op in self.ops_list: + if modules[op] is None: + continue + elif op == "weight": + # dropout before weight operation + if modules["dropout"] is not None: + self.add_module("dropout", modules["dropout"]) + for key in modules["weight"]: + self.add_module(key, modules["weight"][key]) + else: + self.add_module(op, modules[op]) + + @property + def ops_list(self): + return self.ops_order.split("_") + + @property + def bn_before_weight(self): + for op in self.ops_list: + if op == "bn": + return True + elif op == "weight": + return False + raise ValueError("Invalid ops_order: %s" % self.ops_order) + + def weight_op(self): + raise NotImplementedError + + """ Methods defined in MyModule """ + + def forward(self, x): + # similar to nn.Sequential + for module in self._modules.values(): + x = module(x) + return x + + @property + def module_str(self): + raise NotImplementedError + + @property + def config(self): + return { + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "use_bn": self.use_bn, + "act_func": self.act_func, + "dropout_rate": self.dropout_rate, + "ops_order": self.ops_order, + } + + @staticmethod + def build_from_config(config): + raise NotImplementedError + + +class ConvLayer(My2DLayer): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + 
dilation=1, + groups=1, + bias=False, + has_shuffle=False, + use_se=False, + use_bn=True, + act_func="relu", + dropout_rate=0, + ops_order="weight_bn_act", + ): + # default normal 3x3_Conv with bn and relu + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + self.bias = bias + self.has_shuffle = has_shuffle + self.use_se = use_se + + super(ConvLayer, self).__init__( + in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order + ) + if self.use_se: + self.add_module("se", SEModule(self.out_channels)) + + def weight_op(self): + padding = get_same_padding(self.kernel_size) + if isinstance(padding, int): + padding *= self.dilation + else: + padding[0] *= self.dilation + padding[1] *= self.dilation + + weight_dict = OrderedDict( + { + "conv": nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=padding, + dilation=self.dilation, + groups=min_divisible_value(self.in_channels, self.groups), + bias=self.bias, + ) + } + ) + if self.has_shuffle and self.groups > 1: + weight_dict["shuffle"] = ShuffleLayer(self.groups) + + return weight_dict + + @property + def module_str(self): + if isinstance(self.kernel_size, int): + kernel_size = (self.kernel_size, self.kernel_size) + else: + kernel_size = self.kernel_size + if self.groups == 1: + if self.dilation > 1: + conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1]) + else: + conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1]) + else: + if self.dilation > 1: + conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1]) + else: + conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1]) + conv_str += "_O%d" % self.out_channels + if self.use_se: + conv_str = "SE_" + conv_str + conv_str += "_" + self.act_func.upper() + if self.use_bn: + if isinstance(self.bn, nn.GroupNorm): + conv_str += "_GN%d" % self.bn.num_groups + elif isinstance(self.bn, nn.BatchNorm2d): + conv_str += 
"_BN" + return conv_str + + @property + def config(self): + return { + "name": ConvLayer.__name__, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "groups": self.groups, + "bias": self.bias, + "has_shuffle": self.has_shuffle, + "use_se": self.use_se, + **super(ConvLayer, self).config, + } + + @staticmethod + def build_from_config(config): + return ConvLayer(**config) + + +class IdentityLayer(My2DLayer): + def __init__( + self, + in_channels, + out_channels, + use_bn=False, + act_func=None, + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(IdentityLayer, self).__init__( + in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order + ) + + def weight_op(self): + return None + + @property + def module_str(self): + return "Identity" + + @property + def config(self): + return { + "name": IdentityLayer.__name__, + **super(IdentityLayer, self).config, + } + + @staticmethod + def build_from_config(config): + return IdentityLayer(**config) + + +class LinearLayer(MyModule): + def __init__( + self, + in_features, + out_features, + bias=True, + use_bn=False, + act_func=None, + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(LinearLayer, self).__init__() + + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + self.use_bn = use_bn + self.act_func = act_func + self.dropout_rate = dropout_rate + self.ops_order = ops_order + + """ modules """ + modules = {} + # batch norm + if self.use_bn: + if self.bn_before_weight: + modules["bn"] = nn.BatchNorm1d(in_features) + else: + modules["bn"] = nn.BatchNorm1d(out_features) + else: + modules["bn"] = None + # activation + modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act") + # dropout + if self.dropout_rate > 0: + modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True) + else: + modules["dropout"] = None + # linear + modules["weight"] = { + "linear": nn.Linear(self.in_features, self.out_features, 
self.bias) + } + + # add modules + for op in self.ops_list: + if modules[op] is None: + continue + elif op == "weight": + if modules["dropout"] is not None: + self.add_module("dropout", modules["dropout"]) + for key in modules["weight"]: + self.add_module(key, modules["weight"][key]) + else: + self.add_module(op, modules[op]) + + @property + def ops_list(self): + return self.ops_order.split("_") + + @property + def bn_before_weight(self): + for op in self.ops_list: + if op == "bn": + return True + elif op == "weight": + return False + raise ValueError("Invalid ops_order: %s" % self.ops_order) + + def forward(self, x): + for module in self._modules.values(): + x = module(x) + return x + + @property + def module_str(self): + return "%dx%d_Linear" % (self.in_features, self.out_features) + + @property + def config(self): + return { + "name": LinearLayer.__name__, + "in_features": self.in_features, + "out_features": self.out_features, + "bias": self.bias, + "use_bn": self.use_bn, + "act_func": self.act_func, + "dropout_rate": self.dropout_rate, + "ops_order": self.ops_order, + } + + @staticmethod + def build_from_config(config): + return LinearLayer(**config) + + +class MultiHeadLinearLayer(MyModule): + def __init__( + self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0 + ): + super(MultiHeadLinearLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.num_heads = num_heads + + self.bias = bias + self.dropout_rate = dropout_rate + + if self.dropout_rate > 0: + self.dropout = nn.Dropout(self.dropout_rate, inplace=True) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for k in range(num_heads): + layer = nn.Linear(in_features, out_features, self.bias) + self.layers.append(layer) + + def forward(self, inputs): + if self.dropout is not None: + inputs = self.dropout(inputs) + + outputs = [] + for layer in self.layers: + output = layer.forward(inputs) + outputs.append(output) + + outputs = 
class ZeroLayer(MyModule):
    """Placeholder for a pruned / absent op.

    Callers (e.g. ResidualBlock.forward) are expected to short-circuit around
    a ZeroLayer; actually executing it is a programming error.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Must never be reached -- the surrounding block should have skipped us.
        raise ValueError

    @property
    def module_str(self):
        return "Zero"

    @property
    def config(self):
        return {"name": ZeroLayer.__name__}

    @staticmethod
    def build_from_config(config):
        return ZeroLayer()
inplace=True)), + ] + ) + ) + + pad = get_same_padding(self.kernel_size) + groups = ( + feature_dim + if self.groups is None + else min_divisible_value(feature_dim, self.groups) + ) + depth_conv_modules = [ + ( + "conv", + nn.Conv2d( + feature_dim, + feature_dim, + kernel_size, + stride, + pad, + groups=groups, + bias=False, + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + if self.use_se: + depth_conv_modules.append(("se", SEModule(feature_dim))) + self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules)) + + self.point_linear = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + + def forward(self, x): + if self.inverted_bottleneck: + x = self.inverted_bottleneck(x) + x = self.depth_conv(x) + x = self.point_linear(x) + return x + + @property + def module_str(self): + if self.mid_channels is None: + expand_ratio = self.expand_ratio + else: + expand_ratio = self.mid_channels // self.in_channels + layer_str = "%dx%d_MBConv%d_%s" % ( + self.kernel_size, + self.kernel_size, + expand_ratio, + self.act_func.upper(), + ) + if self.use_se: + layer_str = "SE_" + layer_str + layer_str += "_O%d" % self.out_channels + if self.groups is not None: + layer_str += "_G%d" % self.groups + if isinstance(self.point_linear.bn, nn.GroupNorm): + layer_str += "_GN%d" % self.point_linear.bn.num_groups + elif isinstance(self.point_linear.bn, nn.BatchNorm2d): + layer_str += "_BN" + + return layer_str + + @property + def config(self): + return { + "name": MBConvLayer.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.expand_ratio, + "mid_channels": self.mid_channels, + "act_func": self.act_func, + "use_se": self.use_se, + "groups": self.groups, + } + + @staticmethod + def build_from_config(config): + 
class ResidualBlock(MyModule):
    """Pairs a main conv op with an optional shortcut branch and sums them.

    Either branch may be ``None`` or a :class:`ZeroLayer`, both of which mean
    "this branch is absent".
    """

    def __init__(self, conv, shortcut):
        super().__init__()
        self.conv = conv
        self.shortcut = shortcut

    @staticmethod
    def _is_absent(branch):
        # None and ZeroLayer are the two encodings of a missing branch.
        return branch is None or isinstance(branch, ZeroLayer)

    def forward(self, x):
        if self._is_absent(self.conv):
            # No main op: the block degenerates to identity.
            return x
        if self._is_absent(self.shortcut):
            return self.conv(x)
        return self.conv(x) + self.shortcut(x)

    @property
    def module_str(self):
        conv_repr = self.conv.module_str if self.conv is not None else None
        shortcut_repr = self.shortcut.module_str if self.shortcut is not None else None
        return "(%s, %s)" % (conv_repr, shortcut_repr)

    @property
    def config(self):
        return {
            "name": ResidualBlock.__name__,
            "conv": self.conv.config if self.conv is not None else None,
            "shortcut": self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        # Older serialized configs used the key "mobile_inverted_conv".
        if "conv" in config:
            conv_config = config["conv"]
        else:
            conv_config = config["mobile_inverted_conv"]
        conv = set_layer_from_config(conv_config)
        shortcut = set_layer_from_config(config["shortcut"])
        return ResidualBlock(conv, shortcut)

    @property
    def mobile_inverted_conv(self):
        # Backward-compatible alias for the main conv branch.
        return self.conv
self.mid_channels = feature_dim + + # build modules + self.conv1 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + pad = get_same_padding(self.kernel_size) + self.conv2 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + feature_dim, + feature_dim, + kernel_size, + stride, + pad, + groups=groups, + bias=False, + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + self.conv3 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(self.out_channels)), + ] + ) + ) + + if stride == 1 and in_channels == out_channels: + self.downsample = IdentityLayer(in_channels, out_channels) + elif self.downsample_mode == "conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + in_channels, out_channels, 1, stride, 0, bias=False + ), + ), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + elif self.downsample_mode == "avgpool_conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "avg_pool", + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + padding=0, + ceil_mode=True, + ), + ), + ( + "conv", + nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + else: + raise NotImplementedError + + self.final_act = build_activation(self.act_func, inplace=True) + + def forward(self, x): + residual = self.downsample(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + + x = x + residual + x = self.final_act(x) + return x + + @property + def module_str(self): + return "(%s, %s)" % ( + "%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d" + % ( + self.kernel_size, + self.kernel_size, + self.in_channels, + self.mid_channels, + 
self.out_channels, + self.stride, + self.groups, + ), + "Identity" + if isinstance(self.downsample, IdentityLayer) + else self.downsample_mode, + ) + + @property + def config(self): + return { + "name": ResNetBottleneckBlock.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.expand_ratio, + "mid_channels": self.mid_channels, + "act_func": self.act_func, + "groups": self.groups, + "downsample_mode": self.downsample_mode, + } + + @staticmethod + def build_from_config(config): + return ResNetBottleneckBlock(**config) diff --git a/proard/utils/my_dataloader/__init__.py b/proard/utils/my_dataloader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faee844377de7107c62ca10ecbfb4fd5412e4896 --- /dev/null +++ b/proard/utils/my_dataloader/__init__.py @@ -0,0 +1,2 @@ +from .my_distributed_sampler import * +from .my_random_resize_crop import * diff --git a/proard/utils/my_dataloader/my_data_loader.py b/proard/utils/my_dataloader/my_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8746ed3dae49561f403eb092d8680455e43112b7 --- /dev/null +++ b/proard/utils/my_dataloader/my_data_loader.py @@ -0,0 +1,1050 @@ +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. 
+""" + +import threading +import itertools +import warnings +import multiprocessing as python_multiprocessing +import torch +import torch.multiprocessing as multiprocessing +from torch._utils import ExceptionWrapper +from torch.multiprocessing import Queue as queue +from torch._six import string_classes +from torch.utils.data.dataset import IterableDataset +from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler +from torch.utils.data import _utils + +from .my_data_worker import worker_loop + +__all__ = ["MyDataLoader"] + +get_worker_info = _utils.worker.get_worker_info + +# This function used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate = _utils.collate.default_collate + + +class _DatasetKind(object): + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + Used as sampler for :class:`~torch.utils.data.IterableDataset`. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self): + super(_InfiniteConstantSampler, self).__init__(None) + + def __iter__(self): + while True: + yield None + + +class MyDataLoader(object): + r""" + Data loader. Combines a dataset and a sampler, and provides an iterable over + the given dataset. 
+ + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, :attr:`shuffle` must be ``False``. + batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of + indices at a time. Mutually exclusive with :attr:`batch_size`, + :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. 
(default: ``0``) + worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + ``len(dataset)`` (if implemented) is returned instead, regardless + of multi-process loading configurations, because PyTorch trust + user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. See `Dataset Types`_ for more + details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_. + """ + + __initialized = False + + def __init__( + self, + dataset, + batch_size=1, + shuffle=False, + sampler=None, + batch_sampler=None, + num_workers=0, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + multiprocessing_context=None, + ): + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." 
+ ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + self.dataset = dataset + self.num_workers = num_workers + self.pin_memory = pin_memory + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and `IterableDataset` ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. 
If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if shuffle is not False: + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "shuffle option, but got shuffle={}".format(shuffle) + ) + elif sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "sampler option, but got sampler={}".format(sampler) + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "batch_sampler option, but got batch_sampler={}".format( + batch_sampler + ) + ) + else: + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if shuffle or drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with " + "shuffle, and drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + 
self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if not multiprocessing._supports_context: + raise ValueError( + "multiprocessing_context relies on Python >= 3.4, with " + "support for different start methods" + ) + + if isinstance(multiprocessing_context, string_classes): + valid_start_methods = multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + ( + "multiprocessing_context option " + "should specify a valid start method in {}, but got " + "multiprocessing_context={}" + ).format(valid_start_methods, multiprocessing_context) + ) + multiprocessing_context = multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise ValueError( + ( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + "multiprocessing_context={}" + ).format(multiprocessing_context) + ) + else: + raise ValueError( + ( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + "num_workers={}" + ).format(self.num_workers) + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + 
"drop_last", + "dataset", + ): + raise ValueError( + "{} attribute should not be set after {} is " + "initialized".format(attr, self.__class__.__name__) + ) + + super(MyDataLoader, self).__setattr__(attr, val) + + def __iter__(self): + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + return _MultiProcessingDataLoaderIter(self) + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self): + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. 
+ length = self._IterableDataset_len_called = len(self.dataset) + return length + else: + return len(self._index_sampler) + + +class _BaseDataLoaderIter(object): + def __init__(self, loader): + self._dataset = loader.dataset + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + self._pin_memory = loader.pin_memory and torch.cuda.is_available() + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = torch.empty((), dtype=torch.int64).random_().item() + self._num_yielded = 0 + + def __iter__(self): + return self + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self): + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. " + ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + next = __next__ # Python 2 compatibility + + def __len__(self): + return len(self._index_sampler) + + def __getstate__(self): + # across multiple threads for HOGWILD. 
+ # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super(_SingleProcessDataLoaderIter, self).__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). 
+ # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may alreay be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed. (Hooks freeing those + # resources are registered at importing the Python core libraries at + # the top of this file.) So in `__del__`, we check if + # `_utils.python_exit_status` is set or `None` (freed), and perform + # no-op if so. + # + # Another problem with `__del__` is also related to the library cleanup + # calls. When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. 
All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # Another choice is to just shutdown workers with logic in 1 above + # whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. 
+ # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # exits. + # + # However, in case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. Therefore, + # for both `worker_result_queue` (worker -> main process/pin_memory_thread) + # and each `index_queue` (main process -> worker), we use + # `q.cancel_join_thread()` in sender process before any `q.put` to + # prevent this automatic join. 
queue is **not** going to be read from or written to by another
two separate events are used
happen at timing out
+ # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super(_MultiProcessingDataLoaderIter, self).__init__(loader) + + assert self._num_workers > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + self._worker_result_queue = multiprocessing_context.Queue() + self._worker_pids_set = False + self._shutdown = False + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). 
+        self._workers_status = []
+        for i in range(self._num_workers):
+            index_queue = multiprocessing_context.Queue()
+            # index_queue.cancel_join_thread()
+            w = multiprocessing_context.Process(
+                target=worker_loop,
+                args=(
+                    self._dataset_kind,
+                    self._dataset,
+                    index_queue,
+                    self._worker_result_queue,
+                    self._workers_done_event,
+                    self._auto_collation,
+                    self._collate_fn,
+                    self._drop_last,
+                    self._base_seed + i,
+                    self._worker_init_fn,
+                    i,
+                    self._num_workers,
+                ),
+            )
+            w.daemon = True
+            # NB: Process.start() actually take some time as it needs to
+            #     start a process and pass the arguments over via a pipe.
+            #     Therefore, we only add a worker to self._workers list after
+            #     it started, so that we do not call .join() if program dies
+            #     before it starts, and __del__ tries to join but will get:
+            #     AssertionError: can only join a started process.
+            w.start()
+            self._index_queues.append(index_queue)
+            self._workers.append(w)
+            self._workers_status.append(True)
+
+        if self._pin_memory:
+            self._pin_memory_thread_done_event = threading.Event()
+            # `queue` is the stdlib `queue` *module* here (see the
+            # `queue.Empty` check in `_try_get_data`), so the pin-memory
+            # output queue must be `queue.Queue()`; calling the module
+            # itself (`queue()`) would raise a TypeError.
+            self._data_queue = queue.Queue()
+            pin_memory_thread = threading.Thread(
+                target=_utils.pin_memory._pin_memory_loop,
+                args=(
+                    self._worker_result_queue,
+                    self._data_queue,
+                    torch.cuda.current_device(),
+                    self._pin_memory_thread_done_event,
+                ),
+            )
+            pin_memory_thread.daemon = True
+            pin_memory_thread.start()
+            # Similar to workers (see comment above), we only register
+            # pin_memory_thread once it is started.
+            self._pin_memory_thread = pin_memory_thread
+        else:
+            self._data_queue = self._worker_result_queue
+
+        _utils.signal_handling._set_worker_pids(
+            id(self), tuple(w.pid for w in self._workers)
+        )
+        _utils.signal_handling._set_SIGCHLD_handler()
+        self._worker_pids_set = True
+
+        # prime the prefetch loop
+        for _ in range(2 * self._num_workers):
+            self._try_put_index()
+
+    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+        # Tries to fetch data from `self._data_queue` once for a given timeout.
# This raises a `RuntimeError` if any worker died unexpectedly. This error
+ if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + "DataLoader timed out after {} seconds".format(self._timeout) + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. 
+ while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return self._process_data(data) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + self._shutdown_worker(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + # store out-of-order samples + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return self._process_data(data) + + def _try_put_index(self): + assert self._tasks_outstanding < 2 * self._num_workers + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _shutdown_worker(self, worker_id): + # Mark a worker as having finished its work and dead, e.g., due to + 
# Joining is deferred to `_shutdown_workers`, which is called when
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+""" + +import torch +import random +import os +from collections import namedtuple +# from torch._six import queue +from torch.multiprocessing import Queue as queue +from torch._utils import ExceptionWrapper +from torch.utils.data._utils import ( + signal_handling, + MP_STATUS_CHECK_INTERVAL, + IS_WINDOWS, +) + +from .my_random_resize_crop import MyRandomResizedCrop + +__all__ = ["worker_loop"] + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import DWORD, BOOL, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess( + SYNCHRONIZE, 0, self.manager_pid + ) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = ( + self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + ) + return not self.manager_dead + + +else: + + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info = None + + +class 
WorkerInfo(object): + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError( + "Cannot assign attributes to {} objects".format(self.__class__.__name__) + ) + return super(WorkerInfo, self).__setattr__(key, val) + + +def get_worker_info(): + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code (e.g., NumPy). 
+ """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" +_IterableDatasetStopIteration = namedtuple( + "_IterableDatasetStopIteration", ["worker_id"] +) + + +def worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + seed, + init_fn, + worker_id, + num_workers, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset + ) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + except Exception: + init_exception = ExceptionWrapper( + where="in DataLoader worker process {}".format(worker_id) + ) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. 
# `iteration_end` is set, we skip all processing steps and just wait for
+ pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/proard/utils/my_dataloader/my_distributed_sampler.py b/proard/utils/my_dataloader/my_distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..66fae42c7b71e40de035c4bb3a0385ab7b3f4c58 --- /dev/null +++ b/proard/utils/my_dataloader/my_distributed_sampler.py @@ -0,0 +1,87 @@ +import math +import torch +from torch.utils.data.distributed import DistributedSampler + +__all__ = ["MyDistributedSampler", "WeightedDistributedSampler"] + + +class MyDistributedSampler(DistributedSampler): + """Allow Subset Sampler in Distributed Training""" + + def __init__( + self, dataset, num_replicas=None, rank=None, shuffle=True, sub_index_list=None + ): + super(MyDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle) + self.sub_index_list = sub_index_list # numpy + + self.num_samples = int( + math.ceil(len(self.sub_index_list) * 1.0 / self.num_replicas) + ) + self.total_size = self.num_samples * self.num_replicas + print("Use MyDistributedSampler: %d, %d" % (self.num_samples, self.total_size)) + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.sub_index_list), generator=g).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + indices = self.sub_index_list[indices].tolist() + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + +class WeightedDistributedSampler(DistributedSampler): + """Allow Weighted Random Sampling in Distributed Training""" + + def __init__( + self, + dataset, + num_replicas=None, + rank=None, + shuffle=True, + weights=None, + replacement=True, + ): + super(WeightedDistributedSampler, self).__init__( + 
dataset, num_replicas, rank, shuffle + ) + + self.weights = ( + torch.as_tensor(weights, dtype=torch.double) + if weights is not None + else None + ) + self.replacement = replacement + print("Use WeightedDistributedSampler") + + def __iter__(self): + if self.weights is None: + return super(WeightedDistributedSampler, self).__iter__() + else: + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + # original: indices = torch.randperm(len(self.dataset), generator=g).tolist() + indices = torch.multinomial( + self.weights, len(self.dataset), self.replacement, generator=g + ).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/proard/utils/my_dataloader/my_random_resize_crop.py b/proard/utils/my_dataloader/my_random_resize_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..e055d5bdf16c4f7155a64e8ba93304747b9bc8a2 --- /dev/null +++ b/proard/utils/my_dataloader/my_random_resize_crop.py @@ -0,0 +1,161 @@ +import time +import random +import math +import os +from PIL import Image + +import torchvision.transforms.functional as F +import torchvision.transforms as transforms + +__all__ = ["MyRandomResizedCrop", "MyResizeRandomCrop", "MyResize"] + +_pil_interpolation_to_str = { + Image.NEAREST: "PIL.Image.NEAREST", + Image.BILINEAR: "PIL.Image.BILINEAR", + Image.BICUBIC: "PIL.Image.BICUBIC", + Image.LANCZOS: "PIL.Image.LANCZOS", + Image.HAMMING: "PIL.Image.HAMMING", + Image.BOX: "PIL.Image.BOX", +} + + +class MyRandomResizedCrop(transforms.RandomResizedCrop): + ACTIVE_SIZE = 224 + IMAGE_SIZE_LIST = [224] + IMAGE_SIZE_SEG = 4 + + CONTINUOUS = False + SYNC_DISTRIBUTED = True + + EPOCH = 0 + BATCH = 0 + + def __init__( + 
self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=Image.BILINEAR, + ): + if not isinstance(size, int): + size = size[0] + super(MyRandomResizedCrop, self).__init__(size, scale, ratio, interpolation) + + def __call__(self, img): + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F.resized_crop( + img, + i, + j, + h, + w, + (MyRandomResizedCrop.ACTIVE_SIZE, MyRandomResizedCrop.ACTIVE_SIZE), + self.interpolation, + ) + + @staticmethod + def get_candidate_image_size(): + if MyRandomResizedCrop.CONTINUOUS: + min_size = min(MyRandomResizedCrop.IMAGE_SIZE_LIST) + max_size = max(MyRandomResizedCrop.IMAGE_SIZE_LIST) + candidate_sizes = [] + for i in range(min_size, max_size + 1): + if i % MyRandomResizedCrop.IMAGE_SIZE_SEG == 0: + candidate_sizes.append(i) + else: + candidate_sizes = MyRandomResizedCrop.IMAGE_SIZE_LIST + + relative_probs = None + return candidate_sizes, relative_probs + + @staticmethod + def sample_image_size(batch_id=None): + if batch_id is None: + batch_id = MyRandomResizedCrop.BATCH + if MyRandomResizedCrop.SYNC_DISTRIBUTED: + _seed = int("%d%.3d" % (batch_id, MyRandomResizedCrop.EPOCH)) + else: + _seed = os.getpid() + time.time() + random.seed(_seed) + candidate_sizes, relative_probs = MyRandomResizedCrop.get_candidate_image_size() + MyRandomResizedCrop.ACTIVE_SIZE = random.choices( + candidate_sizes, weights=relative_probs + )[0] + + def __repr__(self): + format_string = self.__class__.__name__ + "(size={0}".format( + MyRandomResizedCrop.IMAGE_SIZE_LIST + ) + if MyRandomResizedCrop.CONTINUOUS: + format_string += "@continuous" + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + return format_string + + +class MyResizeRandomCrop(object): + def __init__( + self, + interpolation=Image.BILINEAR, + use_padding=False, + pad_if_needed=False, + fill=0, + padding_mode="constant", + ): + # resize + 
self.interpolation = interpolation + # random crop + self.use_padding = use_padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def __call__(self, img): + crop_size = MyRandomResizedCrop.ACTIVE_SIZE + + if not self.use_padding: + resize_size = int(math.ceil(crop_size / 0.875)) + img = F.resize(img, resize_size, self.interpolation) + else: + img = F.resize(img, crop_size, self.interpolation) + padding_size = crop_size // 8 + img = F.pad(img, padding_size, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.size[0] < crop_size: + img = F.pad(img, (crop_size - img.size[0], 0), self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.size[1] < crop_size: + img = F.pad(img, (0, crop_size - img.size[1]), self.fill, self.padding_mode) + + i, j, h, w = transforms.RandomCrop.get_params(img, (crop_size, crop_size)) + return F.crop(img, i, j, h, w) + + def __repr__(self): + return ( + "MyResizeRandomCrop(size=%s%s, interpolation=%s, use_padding=%s, fill=%s)" + % ( + MyRandomResizedCrop.IMAGE_SIZE_LIST, + "@continuous" if MyRandomResizedCrop.CONTINUOUS else "", + _pil_interpolation_to_str[self.interpolation], + self.use_padding, + self.fill, + ) + ) + + +class MyResize(object): + def __init__(self, interpolation=Image.BILINEAR): + self.interpolation = interpolation + + def __call__(self, img): + target_size = MyRandomResizedCrop.ACTIVE_SIZE + img = F.resize(img, target_size, self.interpolation) + return img + + def __repr__(self): + return "MyResize(size=%s%s, interpolation=%s)" % ( + MyRandomResizedCrop.IMAGE_SIZE_LIST, + "@continuous" if MyRandomResizedCrop.CONTINUOUS else "", + _pil_interpolation_to_str[self.interpolation], + ) diff --git a/proard/utils/my_modules.py b/proard/utils/my_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcac20bb77aa5b02dfe7493ecb0544675de26c7 --- /dev/null +++ b/proard/utils/my_modules.py 
@@ -0,0 +1,293 @@
+# Once for All: Train One Network and Specialize it for Efficient Deployment
+# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
+# International Conference on Learning Representations (ICLR), 2020.
+
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .common_tools import min_divisible_value
+
+__all__ = [
+    "MyModule",
+    "MyNetwork",
+    "init_models",
+    "set_bn_param",
+    "get_bn_param",
+    "replace_bn_with_gn",
+    "MyConv2d",
+    "replace_conv2d_with_my_conv2d",
+]
+
+
+def set_bn_param(net, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs):
+    # Configure every normalization layer of `net` in place.  Optionally swaps
+    # BatchNorm2d for GroupNorm first (gn_channel_per_group) and enables weight
+    # standardization on bias-free conv layers (ws_eps).
+    replace_bn_with_gn(net, gn_channel_per_group)
+
+    for m in net.modules():
+        if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]:
+            m.momentum = momentum
+            m.eps = eps
+        elif isinstance(m, nn.GroupNorm):
+            # GroupNorm keeps no running statistics, so only eps applies
+            m.eps = eps
+
+    replace_conv2d_with_my_conv2d(net, ws_eps)
+    return
+
+
+def get_bn_param(net):
+    # Inverse of set_bn_param: read the settings back from the first
+    # BatchNorm/GroupNorm encountered (layers are assumed uniformly configured).
+    ws_eps = None
+    for m in net.modules():
+        if isinstance(m, MyConv2d):
+            ws_eps = m.WS_EPS
+            break
+    for m in net.modules():
+        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+            return {
+                "momentum": m.momentum,
+                "eps": m.eps,
+                "ws_eps": ws_eps,
+            }
+        elif isinstance(m, nn.GroupNorm):
+            return {
+                "momentum": None,
+                "eps": m.eps,
+                "gn_channel_per_group": m.num_channels // m.num_groups,
+                "ws_eps": ws_eps,
+            }
+    return None
+
+
+def replace_bn_with_gn(model, gn_channel_per_group):
+    # Replace every BatchNorm2d child with a GroupNorm of roughly
+    # gn_channel_per_group channels per group, copying the affine parameters.
+    # No-op when gn_channel_per_group is None.
+    if gn_channel_per_group is None:
+        return
+
+    for m in model.modules():
+        to_replace_dict = {}
+        for name, sub_m in m.named_children():
+            if isinstance(sub_m, nn.BatchNorm2d):
+                num_groups = sub_m.num_features // min_divisible_value(
+                    sub_m.num_features, gn_channel_per_group
+                )
+                gn_m = nn.GroupNorm(
+                    num_groups=num_groups,
+                    num_channels=sub_m.num_features,
+                    eps=sub_m.eps,
+                    affine=True,
+                )
+
+                # load weight
+                gn_m.weight.data.copy_(sub_m.weight.data)
+                gn_m.bias.data.copy_(sub_m.bias.data)
+                # load requires_grad
+                gn_m.weight.requires_grad =
sub_m.weight.requires_grad + gn_m.bias.requires_grad = sub_m.bias.requires_grad + + to_replace_dict[name] = gn_m + m._modules.update(to_replace_dict) + + +def replace_conv2d_with_my_conv2d(net, ws_eps=None): + if ws_eps is None: + return + + for m in net.modules(): + to_update_dict = {} + for name, sub_module in m.named_children(): + if isinstance(sub_module, nn.Conv2d) and not sub_module.bias: + # only replace conv2d layers that are followed by normalization layers (i.e., no bias) + to_update_dict[name] = sub_module + for name, sub_module in to_update_dict.items(): + m._modules[name] = MyConv2d( + sub_module.in_channels, + sub_module.out_channels, + sub_module.kernel_size, + sub_module.stride, + sub_module.padding, + sub_module.dilation, + sub_module.groups, + sub_module.bias, + ) + # load weight + m._modules[name].load_state_dict(sub_module.state_dict()) + # load requires_grad + m._modules[name].weight.requires_grad = sub_module.weight.requires_grad + if sub_module.bias is not None: + m._modules[name].bias.requires_grad = sub_module.bias.requires_grad + # set ws_eps + for m in net.modules(): + if isinstance(m, MyConv2d): + m.WS_EPS = ws_eps + + +def init_models(net, model_init="he_fout"): + """ + Conv2d, + BatchNorm2d, BatchNorm1d, GroupNorm + Linear, + """ + if isinstance(net, list): + for sub_net in net: + init_models(sub_net, model_init) + return + for m in net.modules(): + if isinstance(m, nn.Conv2d): + if model_init == "he_fout": + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif model_init == "he_fin": + n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + else: + raise NotImplementedError + if m.bias is not None: + m.bias.data.zero_() + elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]: + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + stdv = 1.0 / math.sqrt(m.weight.size(1)) + 
m.weight.data.uniform_(-stdv, stdv)
+            if m.bias is not None:
+                m.bias.data.zero_()
+
+
+class MyConv2d(nn.Conv2d):
+    """
+    Conv2d with Weight Standardization
+    https://github.com/joe-siyuan-qiao/WeightStandardization
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super(MyConv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+        # WS_EPS is None -> behaves exactly like a plain nn.Conv2d;
+        # set to a float to standardize the kernel on every forward pass.
+        self.WS_EPS = None
+
+    def weight_standardization(self, weight):
+        if self.WS_EPS is not None:
+            # zero-mean the kernel over (in_channels, kh, kw) per output channel
+            weight_mean = (
+                weight.mean(dim=1, keepdim=True)
+                .mean(dim=2, keepdim=True)
+                .mean(dim=3, keepdim=True)
+            )
+            weight = weight - weight_mean
+            # divide by the per-output-channel std, offset by WS_EPS
+            std = (
+                weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1)
+                + self.WS_EPS
+            )
+            weight = weight / std.expand_as(weight)
+        return weight
+
+    def forward(self, x):
+        if self.WS_EPS is None:
+            return super(MyConv2d, self).forward(x)
+        else:
+            return F.conv2d(
+                x,
+                self.weight_standardization(self.weight),
+                self.bias,
+                self.stride,
+                self.padding,
+                self.dilation,
+                self.groups,
+            )
+
+    def __repr__(self):
+        # append the WS_EPS setting inside the closing parenthesis
+        return super(MyConv2d, self).__repr__()[:-1] + ", ws_eps=%s)" % self.WS_EPS
+
+
+class MyModule(nn.Module):
+    # Abstract base: subclasses must provide forward / module_str / config.
+    def forward(self, x):
+        raise NotImplementedError
+
+    @property
+    def module_str(self):
+        raise NotImplementedError
+
+    @property
+    def config(self):
+        raise NotImplementedError
+
+    @staticmethod
+    def build_from_config(config):
+        raise NotImplementedError
+
+
+class MyNetwork(MyModule):
+    # All channel counts in subclasses are kept divisible by this constant.
+    CHANNEL_DIVISIBLE = 8
+
+    def forward(self, x):
+        raise NotImplementedError
+
+    @property
+    def module_str(self):
+        raise NotImplementedError
+
+    @property
+    def config(self):
+        raise NotImplementedError
+
+    @staticmethod
+    def build_from_config(config):
+        raise NotImplementedError
+
+    def zero_last_gamma(self):
+        raise NotImplementedError
+
+    @property
+    def grouped_block_index(self):
+        raise NotImplementedError
+
+    """
implemented methods """
+
+    def set_bn_param(self, momentum, eps, gn_channel_per_group=None, **kwargs):
+        set_bn_param(self, momentum, eps, gn_channel_per_group, **kwargs)
+
+    def get_bn_param(self):
+        return get_bn_param(self)
+
+    def get_parameters(self, keys=None, mode="include"):
+        # Yield trainable parameters, optionally filtered by name substrings:
+        # mode="include" keeps params whose name contains ANY key,
+        # mode="exclude" keeps params whose name contains NONE of the keys.
+        if keys is None:
+            for name, param in self.named_parameters():
+                if param.requires_grad:
+                    yield param
+        elif mode == "include":
+            for name, param in self.named_parameters():
+                flag = False
+                for key in keys:
+                    if key in name:
+                        flag = True
+                        break
+                if flag and param.requires_grad:
+                    yield param
+        elif mode == "exclude":
+            for name, param in self.named_parameters():
+                flag = True
+                for key in keys:
+                    if key in name:
+                        flag = False
+                        break
+                if flag and param.requires_grad:
+                    yield param
+        else:
+            raise ValueError("do not support: %s" % mode)
+
+    def weight_parameters(self):
+        return self.get_parameters()
diff --git a/proard/utils/pytorch_modules.py b/proard/utils/pytorch_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88c4689301b0254f7edf2c7e96bd6660f8cb39e
--- /dev/null
+++ b/proard/utils/pytorch_modules.py
@@ -0,0 +1,161 @@
+# Once for All: Train One Network and Specialize it for Efficient Deployment
+# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
+# International Conference on Learning Representations (ICLR), 2020.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+from .my_modules import MyNetwork
+
+__all__ = [
+    "make_divisible",
+    "build_activation",
+    "ShuffleLayer",
+    "MyGlobalAvgPool2d",
+    "Hswish",
+    "Hsigmoid",
+    "SEModule",
+    "MultiHeadCrossEntropyLoss",
+]
+
+
+def make_divisible(v, divisor, min_val=None):
+    """
+    This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v:
+    :param divisor:
+    :param min_val:
+    :return:
+    """
+    if min_val is None:
+        min_val = divisor
+    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def build_activation(act_func, inplace=True):
+    # Factory mapping an activation name to a module; returns None for "none".
+    if act_func == "relu":
+        return nn.ReLU(inplace=inplace)
+    elif act_func == "relu6":
+        return nn.ReLU6(inplace=inplace)
+    elif act_func == "tanh":
+        return nn.Tanh()
+    elif act_func == "sigmoid":
+        return nn.Sigmoid()
+    elif act_func == "h_swish":
+        return Hswish(inplace=inplace)
+    elif act_func == "h_sigmoid":
+        return Hsigmoid(inplace=inplace)
+    elif act_func is None or act_func == "none":
+        return None
+    else:
+        raise ValueError("do not support: %s" % act_func)
+
+
+class ShuffleLayer(nn.Module):
+    # Channel shuffle (ShuffleNet): interleave channels across `groups`.
+    def __init__(self, groups):
+        super(ShuffleLayer, self).__init__()
+        self.groups = groups
+
+    def forward(self, x):
+        batch_size, num_channels, height, width = x.size()
+        channels_per_group = num_channels // self.groups
+        # reshape
+        x = x.view(batch_size, self.groups, channels_per_group, height, width)
+        x = torch.transpose(x, 1, 2).contiguous()
+        # flatten
+        x = x.view(batch_size, -1, height, width)
+        return x
+
+    def __repr__(self):
+        return "ShuffleLayer(groups=%d)" % self.groups
+
+
+class MyGlobalAvgPool2d(nn.Module):
+    def __init__(self, keep_dim=True):
+        super(MyGlobalAvgPool2d, self).__init__()
+        self.keep_dim = keep_dim
+
+    def forward(self, x):
+        # global average over the spatial dims (dims 3 then 2)
+        return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)
+
+    def __repr__(self):
+        return "MyGlobalAvgPool2d(keep_dim=%s)" % self.keep_dim
+
+
+class Hswish(nn.Module):
+    def __init__(self, inplace=True):
+        super(Hswish, self).__init__()
+        self.inplace = inplace
+
+    def
forward(self, x):
+        # hard swish: x * relu6(x + 3) / 6
+        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+    def __repr__(self):
+        return "Hswish()"
+
+
+class Hsigmoid(nn.Module):
+    def __init__(self, inplace=True):
+        super(Hsigmoid, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        # hard sigmoid: relu6(x + 3) / 6
+        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+    def __repr__(self):
+        return "Hsigmoid()"
+
+
+class SEModule(nn.Module):
+    # Squeeze-and-Excitation block with hard-sigmoid gating.
+    REDUCTION = 4
+
+    def __init__(self, channel, reduction=None):
+        super(SEModule, self).__init__()
+
+        self.channel = channel
+        self.reduction = SEModule.REDUCTION if reduction is None else reduction
+
+        # bottleneck width, kept divisible by the network-wide channel divisor
+        num_mid = make_divisible(
+            self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
+        )
+
+        self.fc = nn.Sequential(
+            OrderedDict(
+                [
+                    ("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
+                    ("relu", nn.ReLU(inplace=True)),
+                    ("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
+                    ("h_sigmoid", Hsigmoid(inplace=True)),
+                ]
+            )
+        )
+
+    def forward(self, x):
+        # squeeze (global average pool) -> excitation -> channel-wise rescale
+        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
+        y = self.fc(y)
+        return x * y
+
+    def __repr__(self):
+        return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)
+
+
+class MultiHeadCrossEntropyLoss(nn.Module):
+    def forward(self, outputs, targets):
+        # outputs: (batch, num_heads, num_classes); targets: (batch, num_heads)
+        assert outputs.dim() == 3, outputs
+        assert targets.dim() == 2, targets
+
+        assert outputs.size(1) == targets.size(1), (outputs, targets)
+        num_heads = targets.size(1)
+
+        # average the per-head cross-entropy losses
+        loss = 0
+        for k in range(num_heads):
+            loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
+        return loss
diff --git a/proard/utils/pytorch_utils.py b/proard/utils/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b3b7331eefad6c3629bccbfc63d8172c1bcd92
--- /dev/null
+++ b/proard/utils/pytorch_utils.py
@@ -0,0 +1,237 @@
+# Once for All: Train One Network and Specialize it for Efficient Deployment
+# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
+# International Conference on
Learning Representations (ICLR), 2020. + +import math +import copy +import time +import torch +import torch.nn as nn + +__all__ = [ + "mix_images", + "mix_labels", + "label_smooth", + "cross_entropy_loss_with_soft_target", + "cross_entropy_with_label_smoothing", + "clean_num_batch_tracked", + "rm_bn_from_net", + "get_net_device", + "count_parameters", + "count_net_flops", + "measure_net_latency", + "get_net_info", + "build_optimizer", + "calc_learning_rate", +] + + +""" Mixup """ + + +def mix_images(images, lam): + flipped_images = torch.flip(images, dims=[0]) # flip along the batch dimension + return lam * images + (1 - lam) * flipped_images + + +def mix_labels(target, lam, n_classes, label_smoothing=0.1): + onehot_target = label_smooth(target, n_classes, label_smoothing) + flipped_target = torch.flip(onehot_target, dims=[0]) + return lam * onehot_target + (1 - lam) * flipped_target + + +""" Label smooth """ + + +def label_smooth(target, n_classes: int, label_smoothing=0.1): + # convert to one-hot + batch_size = target.size(0) + target = torch.unsqueeze(target, 1) + soft_target = torch.zeros((batch_size, n_classes), device=target.device) + soft_target.scatter_(1, target, 1) + # label smoothing + soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes + return soft_target + + +def cross_entropy_loss_with_soft_target(pred, soft_target): + logsoftmax = nn.LogSoftmax() + return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1)) + + +def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1): + soft_target = label_smooth(target, pred.size(1), label_smoothing) + return cross_entropy_loss_with_soft_target(pred, soft_target) + + +""" BN related """ + + +def clean_num_batch_tracked(net): + for m in net.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + if m.num_batches_tracked is not None: + m.num_batches_tracked.zero_() + + +def rm_bn_from_net(net): + for m in net.modules(): + if isinstance(m, 
nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+            # monkey-patch BN to identity (used before latency profiling)
+            m.forward = lambda x: x
+
+
+""" Network profiling """
+
+
+def get_net_device(net):
+    # Device of the first parameter (assumes the whole net is on one device).
+    return net.parameters().__next__().device
+
+
+def count_parameters(net):
+    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
+    return total_params
+
+
+def count_net_flops(net, data_shape=(1, 3, 32, 32)):
+    from .flops_counter import profile
+
+    if isinstance(net, nn.DataParallel):
+        net = net.module
+
+    # profile on a deep copy so the original net is left untouched
+    flop, _ = profile(copy.deepcopy(net), data_shape)
+    return flop
+
+
+def measure_net_latency(
+    net, l_type="gpu8", fast=True, input_shape=(3, 32, 32), clean=False
+):
+    # Wall-clock latency in ms.  l_type is "cpu" or "gpu<batch>", e.g. "gpu8"
+    # measures with batch size 8.  Returns (avg_ms, raw measurements dict).
+    if isinstance(net, nn.DataParallel):
+        net = net.module
+
+    # remove bn from graph
+    rm_bn_from_net(net)
+
+    # return `ms`
+    if "gpu" in l_type:
+        l_type, batch_size = l_type[:3], int(l_type[3:])
+    else:
+        batch_size = 1
+
+    data_shape = [batch_size] + list(input_shape)
+    if l_type == "cpu":
+        if fast:
+            n_warmup = 5
+            n_sample = 10
+        else:
+            n_warmup = 50
+            n_sample = 50
+        if get_net_device(net) != torch.device("cpu"):
+            if not clean:
+                print("move net to cpu for measuring cpu latency")
+            net = copy.deepcopy(net).cpu()
+    elif l_type == "gpu":
+        if fast:
+            n_warmup = 5
+            n_sample = 10
+        else:
+            n_warmup = 50
+            n_sample = 50
+    else:
+        raise NotImplementedError
+    images = torch.zeros(data_shape, device=get_net_device(net))
+
+    measured_latency = {"warmup": [], "sample": []}
+    net.eval()
+    with torch.no_grad():
+        for i in range(n_warmup):
+            inner_start_time = time.time()
+            net(images)
+            used_time = (time.time() - inner_start_time) * 1e3  # ms
+            measured_latency["warmup"].append(used_time)
+            if not clean:
+                print("Warmup %d: %.3f" % (i, used_time))
+        outer_start_time = time.time()
+        for i in range(n_sample):
+            net(images)
+        # average over the timed loop, not per-iteration timers
+        total_time = (time.time() - outer_start_time) * 1e3  # ms
+        measured_latency["sample"].append((total_time, n_sample))
+    return total_time / n_sample, measured_latency
+
+
+def get_net_info(net, input_shape=(3, 32, 32),
measure_latency=None, print_info=True):
+    # Collect params (M), FLOPs (M) and, optionally, latency for each entry of
+    # the '#'-separated `measure_latency` spec (e.g. "gpu8#cpu").
+    net_info = {}
+    if isinstance(net, nn.DataParallel):
+        net = net.module
+
+    # parameters
+    net_info["params"] = count_parameters(net) / 1e6
+
+    # flops
+    net_info["flops"] = count_net_flops(net, [1] + list(input_shape)) / 1e6
+
+    # latencies
+    latency_types = [] if measure_latency is None else measure_latency.split("#")
+    for l_type in latency_types:
+        latency, measured_latency = measure_net_latency(
+            net, l_type, fast=False, input_shape=input_shape
+        )
+        net_info["%s latency" % l_type] = {"val": latency, "hist": measured_latency}
+
+    if print_info:
+        print(net)
+        print("Total training params: %.2fM" % (net_info["params"]))
+        print("Total FLOPs: %.2fM" % (net_info["flops"]))
+        for l_type in latency_types:
+            print(
+                "Estimated %s latency: %.3fms"
+                % (l_type, net_info["%s latency" % l_type]["val"])
+            )
+
+    return net_info
+
+
+""" optimizer """
+
+
+def build_optimizer(
+    net_params, opt_type, opt_param, init_lr, weight_decay, no_decay_keys
+):
+    # With no_decay_keys set, net_params must be the pair
+    # [decay_params, no_decay_params]; the latter group gets weight_decay=0.
+    if no_decay_keys is not None:
+        assert isinstance(net_params, list) and len(net_params) == 2
+        net_params = [
+            {"params": net_params[0], "weight_decay": weight_decay},
+            {"params": net_params[1], "weight_decay": 0},
+        ]
+    else:
+        net_params = [{"params": net_params, "weight_decay": weight_decay}]
+
+    if opt_type == "sgd":
+        opt_param = {} if opt_param is None else opt_param
+        momentum, nesterov = opt_param.get("momentum", 0.9), opt_param.get(
+            "nesterov", True
+        )
+        optimizer = torch.optim.SGD(
+            net_params, init_lr, momentum=momentum, nesterov=nesterov
+        )
+    elif opt_type == "adam":
+        optimizer = torch.optim.Adam(net_params, init_lr)
+    else:
+        raise NotImplementedError
+    return optimizer
+
+
+""" learning rate schedule """
+
+
+def calc_learning_rate(
+    epoch, init_lr, n_epochs, batch=0, nBatch=None, lr_schedule_type="cosine"
+):
+    # Per-iteration cosine decay; nBatch (iterations per epoch) is required
+    # when lr_schedule_type == "cosine".
+    if lr_schedule_type == "cosine":
+        t_total = n_epochs * nBatch
+        t_cur = epoch * nBatch + batch
+        lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur
/ t_total))
+    elif lr_schedule_type is None:
+        # no schedule: constant learning rate
+        lr = init_lr
+    else:
+        raise ValueError("do not support: %s" % lr_schedule_type)
+    return lr