|
|
import logging |
|
|
|
|
|
import timm |
|
|
import torch |
|
|
import torchvision |
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
from robustbench.model_zoo.architectures.utils_architectures import normalize_model |
|
|
from robustbench.model_zoo.enums import ThreatModel |
|
|
from robustbench.utils import load_model |
|
|
|
|
|
from typing import Union |
|
|
from models import resnet26 |
|
|
from packaging import version |
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
def get_torchvision_model(model_name: str, weight_version: str = "IMAGENET1K_V1"):
    """
    Restore a pre-trained model from torchvision
    Further details can be found here: https://pytorch.org/vision/0.14/models.html

    Input:
        model_name: Name of the model to create and initialize with pre-trained weights
        weight_version: Name of the pre-trained weights to restore; only weight names
            containing "IMAGENET1K" are accepted (so the "DEFAULT" alias is rejected)
    Returns:
        model: The pre-trained model, wrapped with `normalize_model` using the mean/std
            of the selected weights' transforms
        preprocess: The corresponding input pre-processing (resize, center crop, ToTensor);
            it does not normalize, since the normalization is applied by the wrapped model
    Raises:
        RuntimeError: If the installed torchvision is older than 0.14
        ValueError: If the model or the requested weights are not available in torchvision
    """
    # list_models / get_model / get_model_weights (the model registration API) were added in
    # torchvision 0.14. Raise explicitly instead of asserting so the check survives `python -O`.
    if version.parse(torchvision.__version__) < version.parse("0.14"):
        raise RuntimeError("Torchvision version has to be >= 0.14")

    # Restrict the lookup to the classification models living directly in torchvision.models.
    available_models = torchvision.models.list_models(module=torchvision.models)
    if model_name not in available_models:
        raise ValueError(f"Model '{model_name}' is not available in torchvision. Choose from: {available_models}")

    weights_enum = torchvision.models.get_model_weights(model_name)
    # Enumerate the enum's members explicitly (rather than dir()) and keep ImageNet-1k weights only.
    available_weights = [name for name in weights_enum.__members__ if "IMAGENET1K" in name]
    if weight_version not in available_weights:
        raise ValueError(f"Weight type '{weight_version}' is not supported for torchvision model '{model_name}'."
                         f" Choose from: {available_weights}")

    model_weights = weights_enum[weight_version]
    model = torchvision.models.get_model(model_name, weights=model_weights)

    # Move the mean/std normalization of the weights' transforms into the model, so that the
    # pre-processing below only has to produce tensors via ToTensor.
    transform = model_weights.transforms()
    model = normalize_model(model, transform.mean, transform.std)
    logger.info(f"Successfully restored '{weight_version}' pre-trained weights"
                f" for model '{model_name}' from torchvision!")

    preprocess = transforms.Compose([transforms.Resize(transform.resize_size, interpolation=transform.interpolation),
                                     transforms.CenterCrop(transform.crop_size),
                                     transforms.ToTensor()])
    return model, preprocess
|
|
|
|
|
|
|
|
def get_timm_model(model_name: str):
    """
    Restore a pre-trained model from timm: https://github.com/huggingface/pytorch-image-models/tree/main/timm
    Quickstart: https://huggingface.co/docs/timm/quickstart

    Input:
        model_name: Name of the model to create and initialize with pre-trained weights. Either a tagged
            name as reported by `timm.list_models(pretrained=True)` ("<architecture>.<tag>") or the bare
            architecture name, in which case timm chooses the pre-trained weights to restore
    Returns:
        model: The pre-trained model, wrapped with `normalize_model` using the mean/std of timm's
            Normalize transform
        preprocess: The corresponding input pre-processing, with the Normalize transform removed
            (the normalization is applied by the wrapped model instead)
    Raises:
        ValueError: If `model_name` does not match any pre-trained model known to timm
    """
    available_models = timm.list_models(pretrained=True)
    # Recent timm versions report pretrained models as "<architecture>.<tag>", so a bare architecture
    # name would never match the list directly. Accept both the tagged names and their architecture prefix.
    known_names = set(available_models)
    known_names.update(name.split(".", 1)[0] for name in available_models)
    if model_name not in known_names:
        raise ValueError(f"Model '{model_name}' is not available in timm. Choose from: {available_models}")

    model = timm.create_model(model_name, pretrained=True)
    logger.info(f"Successfully restored the weights of '{model_name}' from timm.")

    # Build timm's evaluation transform for this model's data config.
    data_config = timm.data.resolve_model_data_config(model)
    preprocess = timm.data.create_transform(**data_config)

    # Move the last Normalize transform of the pipeline into the model. Iterate over a reversed copy
    # so that removing the transform from the original list is safe.
    for transf in preprocess.transforms[::-1]:
        if isinstance(transf, transforms.Normalize):
            model = normalize_model(model, mean=transf.mean, std=transf.std)
            preprocess.transforms.remove(transf)
            break

    return model, preprocess
|
|
|
|
|
|
|
|
def get_model(cfg, num_classes: int, device: Union[str, torch.device]):
    """
    Setup the pre-defined model architecture and restore the corresponding pre-trained weights

    The sources are tried in this order: torchvision, timm, the local checkpoint
    (only for MODEL.ARCH == "resnet26_gn") and finally RobustBench (corruptions threat model).

    Input:
        cfg: Configurations (reads MODEL.ARCH, MODEL.WEIGHTS, MODEL.CKPT_PATH, CKPT_DIR
            and CORRUPTION.DATASET)
        num_classes: Number of classes (currently not used by this function)
        device: The device to put the loaded model
    Return:
        model: The pre-trained model, moved to `device`
        preprocess: The corresponding input pre-processing, or None for models restored
            from the local checkpoint or from RobustBench
    """
    arch = cfg.MODEL.ARCH
    preprocess = None

    try:
        base_model, preprocess = get_torchvision_model(arch, weight_version=cfg.MODEL.WEIGHTS)
    except ValueError as tv_err:
        # Also covers "model exists but weight version is unsupported"; log it so a silent
        # fallback to a different source is traceable.
        logger.debug("Could not restore '%s' from torchvision (%s); trying timm.", arch, tv_err)
        try:
            base_model, preprocess = get_timm_model(arch)
        except ValueError:
            logger.debug("Could not restore '%s' from timm; trying local checkpoint / RobustBench.", arch)
            # Dispatch explicitly instead of raising a ValueError for control flow, so genuine errors
            # while loading the local checkpoint are not silently redirected to RobustBench.
            if arch == "resnet26_gn":
                base_model = _load_resnet26_gn(arch, cfg.MODEL.CKPT_PATH)
            else:
                base_model = _load_robustbench_model(cfg)

    return base_model.to(device), preprocess


def _load_resnet26_gn(arch, ckpt_path):
    """Build `resnet26.build_resnet26()`, restore the 'net' state dict from `ckpt_path` and wrap it with normalization."""
    model = resnet26.build_resnet26()
    # NOTE: torch.load unpickles the checkpoint; only load files from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint['net'])
    model = normalize_model(model, resnet26.MEAN, resnet26.STD)
    logger.info(f"Successfully restored model '{arch}' from: {ckpt_path}")
    return model


def _load_robustbench_model(cfg):
    """Restore `cfg.MODEL.ARCH` from RobustBench (corruptions threat model), using `cfg.CKPT_DIR` as model directory."""
    # 'ccc' is mapped to the ImageNet model zoo; otherwise the prefix before the first '_'
    # of the dataset name selects the RobustBench dataset.
    if cfg.CORRUPTION.DATASET == 'ccc':
        dataset_name = 'imagenet'
    else:
        dataset_name = cfg.CORRUPTION.DATASET.split("_")[0]
    return load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, dataset_name, ThreatModel.corruptions)
|
|
|
|
|
|