File size: 6,108 Bytes
02ba886 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import logging
import timm
import torch
import torchvision
import torchvision.transforms as transforms
from robustbench.model_zoo.architectures.utils_architectures import normalize_model
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
from typing import Union
from models import resnet26
from packaging import version
logger = logging.getLogger(__name__)
def get_torchvision_model(model_name: str, weight_version: str = "IMAGENET1K_V1"):
    """
    Restore a pre-trained model from torchvision
    Further details can be found here: https://pytorch.org/vision/0.14/models.html
    Input:
        model_name: Name of the model to create and initialize with pre-trained weights
        weight_version: Name of the pre-trained weights to restore; must be one of the model's
                        weight entries whose name contains "IMAGENET1K" (e.g. "IMAGENET1K_V1")
    Returns:
        model: The pre-trained model, wrapped with an input normalization that uses the
               mean/std of the selected weights
        preprocess: The corresponding input pre-processing (resize, center crop, ToTensor).
                    It does not normalize, because the normalization is part of the returned model.
    Raises:
        RuntimeError: If the installed torchvision is older than 0.14
        ValueError: If the model or the requested weights are not available in torchvision
    """
    # list_models / get_model / get_model_weights (the model registration API) only exist since
    # torchvision 0.14. Compare the release numbers only, so that pre-release or local builds
    # (e.g. "0.14.0a0", "0.15.2+cu118") are judged by their version number alone. Raise instead of
    # assert so the check is not stripped when Python runs with -O.
    if version.parse(torchvision.__version__).release[:2] < (0, 14):
        raise RuntimeError(f"Torchvision version has to be >= 0.14 (found {torchvision.__version__})")

    # check if the specified model name is available in torchvision
    available_models = torchvision.models.list_models(module=torchvision.models)
    if model_name not in available_models:
        raise ValueError(f"Model '{model_name}' is not available in torchvision. Choose from: {available_models}")

    # get the weights enum of the specified model; only its ImageNet-1k entries are offered
    weights_enum = torchvision.models.get_model_weights(model_name)
    available_weights = [entry.name for entry in weights_enum if "IMAGENET1K" in entry.name]

    # check if the specified type of weights is available
    if weight_version not in available_weights:
        raise ValueError(f"Weight type '{weight_version}' is not supported for torchvision model '{model_name}'."
                         f" Choose from: {available_weights}")

    # setup the specified model and initialize it with the specified pre-trained weights
    model_weights = getattr(weights_enum, weight_version)
    model = torchvision.models.get_model(model_name, weights=model_weights)

    # move the input normalization of the weights' preset transform into the model
    transform = model_weights.transforms()
    model = normalize_model(model, transform.mean, transform.std)
    logger.info(f"Successfully restored '{weight_version}' pre-trained weights"
                f" for model '{model_name}' from torchvision!")

    # create the corresponding input transformation (without normalization, see above)
    preprocess = transforms.Compose([
        transforms.Resize(transform.resize_size, interpolation=transform.interpolation),
        transforms.CenterCrop(transform.crop_size),
        transforms.ToTensor(),
    ])
    return model, preprocess
def get_timm_model(model_name: str):
    """
    Restore a pre-trained model from timm: https://github.com/huggingface/pytorch-image-models/tree/main/timm
    Quickstart: https://huggingface.co/docs/timm/quickstart
    Input:
        model_name: Name of the model to create and initialize with pre-trained weights. Both a tagged
                    name as listed by timm (e.g. "resnet50.a1_in1k") and a bare architecture name
                    (e.g. "resnet50", for which timm restores its default pre-trained tag) are accepted.
    Returns:
        model: The pre-trained model; if timm's pre-processing contains a Normalize step, that
               normalization is moved into the model
        preprocess: The corresponding input pre-processing (without the Normalize step, if one existed)
    Raises:
        ValueError: If no pre-trained timm model matches model_name
    """
    # check if the defined model name is supported as pre-trained model
    available_models = timm.list_models(pretrained=True)
    # Recent timm versions list pre-trained models as "<architecture>.<tag>". Also accept the bare
    # architecture name; for older timm versions without tags, this set equals the listed names.
    available_archs = {name.split(".")[0] for name in available_models}
    if model_name not in available_models and model_name not in available_archs:
        raise ValueError(f"Model '{model_name}' is not available in timm. Choose from: {available_models}")

    # setup pre-trained model
    model = timm.create_model(model_name, pretrained=True)
    logger.info(f"Successfully restored the weights of '{model_name}' from timm.")

    # restore the input pre-processing
    data_config = timm.data.resolve_model_data_config(model)
    preprocess = timm.data.create_transform(**data_config)

    # if there is an input normalization, add the last one to the model and remove it from the
    # input pre-processing, so the model receives un-normalized inputs
    normalize_steps = [step for step in preprocess.transforms if isinstance(step, transforms.Normalize)]
    if normalize_steps:
        last_normalize = normalize_steps[-1]
        model = normalize_model(model, mean=last_normalize.mean, std=last_normalize.std)
        preprocess.transforms.remove(last_normalize)
    return model, preprocess
def get_model(cfg, num_classes: int, device: Union[str, torch.device]):
    """
    Setup the pre-defined model architecture and restore the corresponding pre-trained weights.
    The sources are tried in this order: torchvision, timm, the custom models of this repository,
    and finally RobustBench.
    Input:
        cfg: Configurations
        num_classes: Number of classes (currently unused; kept for interface compatibility)
        device: The device to put the loaded model
    Return:
        model: The pre-trained model, moved to `device`
        preprocess: The corresponding input pre-processing, or None for custom and RobustBench models
    """
    preprocess = None
    try:
        # load model from torchvision
        base_model, preprocess = get_torchvision_model(cfg.MODEL.ARCH, weight_version=cfg.MODEL.WEIGHTS)
    except ValueError as torchvision_err:
        # Keep the reason: an unsupported MODEL.WEIGHTS for a torchvision model also ends up here.
        logger.debug("Not restoring '%s' from torchvision: %s", cfg.MODEL.ARCH, torchvision_err)
        try:
            # load model from timm
            base_model, preprocess = get_timm_model(cfg.MODEL.ARCH)
        except ValueError as timm_err:
            logger.debug("Not restoring '%s' from timm: %s", cfg.MODEL.ARCH, timm_err)
            # load some custom models; fall back to robustbench for any other architecture name.
            # Errors while loading a custom checkpoint now propagate instead of being masked.
            base_model = _load_custom_model(cfg)
            if base_model is None:
                base_model = _load_robustbench_model(cfg)
    return base_model.to(device), preprocess


def _load_custom_model(cfg):
    """
    Build one of the custom architectures of this repository and restore its checkpoint
    Input:
        cfg: Configurations (reads MODEL.ARCH and MODEL.CKPT_PATH)
    Return:
        model: The restored model wrapped with its input normalization, or None if MODEL.ARCH
               is not a custom architecture
    """
    if cfg.MODEL.ARCH != "resnet26_gn":
        return None
    base_model = resnet26.build_resnet26()
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
    # only point MODEL.CKPT_PATH at trusted checkpoints.
    checkpoint = torch.load(cfg.MODEL.CKPT_PATH, map_location="cpu")
    base_model.load_state_dict(checkpoint['net'])
    base_model = normalize_model(base_model, resnet26.MEAN, resnet26.STD)
    logger.info(f"Successfully restored model '{cfg.MODEL.ARCH}' from: {cfg.MODEL.CKPT_PATH}")
    return base_model


def _load_robustbench_model(cfg):
    """
    Restore a model from the RobustBench model zoo for the corruptions threat model
    Input:
        cfg: Configurations (reads MODEL.ARCH, CKPT_DIR and CORRUPTION.DATASET)
    Return:
        model: The model returned by robustbench's load_model
    """
    # 'ccc' is mapped to RobustBench's ImageNet zoo; for every other dataset, the prefix of the
    # dataset name before the first underscore is used as the RobustBench dataset name.
    if cfg.CORRUPTION.DATASET == 'ccc':
        dataset_name = 'imagenet'
    else:
        dataset_name = cfg.CORRUPTION.DATASET.split("_")[0]
    return load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, dataset_name, ThreatModel.corruptions)
|