import argparse
import dataclasses
import json
import math
import os
import warnings
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional, Union
import requests
import torch
from torch import nn
from robustbench.model_zoo import model_dicts as all_models
from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
# Maps each threat model to the JSON field that stores its robust accuracy.
# L2 and Linf leaderboards both report AutoAttack accuracy; the corruptions
# leaderboard reports mean corruption accuracy.
ACC_FIELDS = {
    ThreatModel.corruptions: "corruptions_acc",
    ThreatModel.L2: "autoattack_acc",
    ThreatModel.Linf: "autoattack_acc"
}
# Desktop-browser User-Agent sent to Google Drive; downloads from scripts
# without a browser-like UA get served an interstitial HTML page instead of
# the file (see the gdown fix referenced in download_gdrive).
CANNED_USER_AGENT="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" # NOQA
def download_gdrive(gdrive_id, fname_save):
    """Download a file from Google Drive to ``fname_save``.

    source: https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

    :param gdrive_id: Google Drive id of the file to download.
    :param fname_save: Destination path for the downloaded file.
    """
    def get_confirm_token(response):
        # Large files that cannot be virus-scanned set a 'download_warning'
        # cookie whose value must be echoed back to confirm the download.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, fname_save):
        CHUNK_SIZE = 32768
        with open(fname_save, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    print('Download started: path={} (gdrive_id={})'.format(
        fname_save, gdrive_id))

    url_base = "https://docs.google.com/uc?export=download&confirm=t"
    # Use the session as a context manager so it is closed even if the
    # request or the file write raises (the original leaked it on error).
    with requests.Session() as session:
        # Fix from https://github.com/wkentaro/gdown/pull/294.
        session.headers.update({"User-Agent": CANNED_USER_AGENT})

        response = session.get(url_base, params={'id': gdrive_id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': gdrive_id, 'confirm': token}
            response = session.get(url_base, params=params, stream=True)

        save_response_content(response, fname_save)

    print('Download finished: path={} (gdrive_id={})'.format(
        fname_save, gdrive_id))
def rm_substr_from_state_dict(state_dict, substr):
    """Return a copy of ``state_dict`` where ``substr`` is stripped from the
    front of every key that contains it (e.g. a leading 'module.' prefix
    added by ``DataParallel``)."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Drop the first len(substr) characters whenever the marker occurs,
        # otherwise keep the key untouched.
        new_key = key[len(substr):] if substr in key else key
        stripped[new_key] = value
    return stripped
def add_substr_to_state_dict(state_dict, substr):
    """Return a copy of ``state_dict`` with ``substr`` prepended to every key."""
    return OrderedDict(
        (substr + key, value) for key, value in state_dict.items())
def load_model(model_name: str,
               model_dir: Union[str, Path] = './models',
               dataset: Union[str,
                              BenchmarkDataset] = BenchmarkDataset.cifar_10,
               threat_model: Union[str, ThreatModel] = ThreatModel.Linf,
               norm: Optional[str] = None) -> nn.Module:
    """Loads a model from the model_zoo.

    The model is trained on the given ``dataset``, for the given ``threat_model``.

    :param model_name: The name used in the model zoo.
    :param model_dir: The base directory where the models are saved.
    :param dataset: The dataset on which the model is trained.
    :param threat_model: The threat model for which the model is trained.
    :param norm: Deprecated argument that can be used in place of ``threat_model``. If specified, it
      overrides ``threat_model``
    :return: A ready-to-used trained model.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    if norm is None:
        threat_model_: ThreatModel = ThreatModel(threat_model)
    else:
        threat_model_ = ThreatModel(norm)
        warnings.warn(
            "`norm` has been deprecated and will be removed in a future version.",
            DeprecationWarning)
    model_dir_ = Path(model_dir) / dataset_.value / threat_model_.value
    model_path = model_dir_ / f'{model_name}.pt'
    models = all_models[dataset_][threat_model_]

    if not isinstance(models[model_name]['gdrive_id'], list):
        model = models[model_name]['model']()
        # ImageNet 'Standard' models ship with their weights already loaded.
        if dataset_ == BenchmarkDataset.imagenet and 'Standard' in model_name:
            return model.eval()
        # exist_ok avoids the racy exists()-then-makedirs() pair.
        os.makedirs(model_dir_, exist_ok=True)
        if not os.path.isfile(model_path):
            download_gdrive(models[model_name]['gdrive_id'], model_path)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        if 'Kireev2021Effectiveness' in model_name or model_name == 'Andriushchenko2020Understanding':
            checkpoint = checkpoint['last']  # we take the last model (choices: 'last', 'best')
        try:
            # needed for the model of `Carmon2019Unlabeled`
            state_dict = rm_substr_from_state_dict(checkpoint['state_dict'],
                                                   'module.')
            # needed for the model of `Chen2020Efficient`
            state_dict = rm_substr_from_state_dict(state_dict, 'model.')
        except (KeyError, TypeError):
            # The checkpoint is a bare state_dict, not a wrapper dict with a
            # 'state_dict' entry. (The original bare `except:` also hid
            # unrelated errors such as KeyboardInterrupt.)
            state_dict = rm_substr_from_state_dict(checkpoint, 'module.')
            state_dict = rm_substr_from_state_dict(state_dict, 'model.')

        if dataset_ == BenchmarkDataset.imagenet:
            # so far all models need input normalization, which is added as extra layer
            state_dict = add_substr_to_state_dict(state_dict, 'model.')

        model = _safe_load_state_dict(model, model_name, state_dict, dataset_)
        return model.eval()
    # If we have an ensemble of models (e.g., Chen2020Adversarial)
    else:
        model = models[model_name]['model']()
        os.makedirs(model_dir_, exist_ok=True)
        for i, gid in enumerate(models[model_name]['gdrive_id']):
            # One checkpoint file per sub-model of the ensemble.
            sub_model_path = '{}_m{}.pt'.format(model_path, i)
            if not os.path.isfile(sub_model_path):
                download_gdrive(gid, sub_model_path)
            checkpoint = torch.load(sub_model_path,
                                    map_location=torch.device('cpu'))
            try:
                state_dict = rm_substr_from_state_dict(
                    checkpoint['state_dict'], 'module.')
            except KeyError:
                state_dict = rm_substr_from_state_dict(checkpoint, 'module.')
            model.models[i] = _safe_load_state_dict(model.models[i],
                                                    model_name, state_dict,
                                                    dataset_)
            model.models[i].eval()
        return model.eval()
def _safe_load_state_dict(model: nn.Module, model_name: str,
                          state_dict: Dict[str, torch.Tensor],
                          dataset_: BenchmarkDataset) -> nn.Module:
    """Load ``state_dict`` into ``model``, tolerating known key mismatches.

    A strict load is attempted first; for models whose published checkpoints
    are known to miss (or carry extra) input-normalization keys, the load is
    retried non-strictly. Any other mismatch is re-raised.

    :param model: The freshly constructed model to load weights into.
    :param model_name: Zoo id, used to decide whether a mismatch is expected.
    :param state_dict: The checkpoint's (already prefix-cleaned) state dict.
    :param dataset_: Benchmark dataset; ImageNet models always get the
        non-strict fallback for the known failure messages.
    :return: ``model`` with the weights loaded.
    """
    # Checkpoints for these zoo entries are known to mismatch the module's
    # normalization buffers; a non-strict load is expected for them.
    # (A duplicate 'Kireev2021Effectiveness_RLATAugMixNoJSD' entry was removed.)
    known_failing_models = {
        "Andriushchenko2020Understanding", "Augustin2020Adversarial",
        "Engstrom2019Robustness", "Pang2020Boosting", "Rice2020Overfitting",
        "Rony2019Decoupling", "Wong2020Fast", "Hendrycks2020AugMix_WRN",
        "Hendrycks2020AugMix_ResNeXt", "Kireev2021Effectiveness_Gauss50percent",
        "Kireev2021Effectiveness_AugMixNoJSD", "Kireev2021Effectiveness_RLAT",
        "Kireev2021Effectiveness_RLATAugMixNoJSD",
        "Kireev2021Effectiveness_RLATAugMix", "Chen2020Efficient",
        "Wu2020Adversarial", "Augustin2020Adversarial_34_10",
        "Augustin2020Adversarial_34_10_extra"
    }

    # Only these specific mismatch messages are forgiven.
    failure_messages = ['Missing key(s) in state_dict: "mu", "sigma".',
                        'Unexpected key(s) in state_dict: "model_preact_hl1.1.weight"',
                        'Missing key(s) in state_dict: "normalize.mean", "normalize.std"']

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        if (model_name in known_failing_models
                or dataset_ == BenchmarkDataset.imagenet) and any(
                    msg in str(e) for msg in failure_messages):
            model.load_state_dict(state_dict, strict=False)
        else:
            raise  # bare raise preserves the original traceback

    return model
def clean_accuracy(model: nn.Module,
                   x: torch.Tensor,
                   y: torch.Tensor,
                   batch_size: int = 100,
                   identifier=None,
                   device: torch.device = None):
    """Compute top-1 accuracy of ``model`` on ``(x, y)`` in minibatches.

    :param model: Classifier whose output's argmax over dim 1 is the prediction.
    :param x: Input batch (first dimension indexes examples).
    :param y: Integer labels, same length as ``x``.
    :param batch_size: Number of examples per forward pass.
    :param identifier: Unused; kept for backward compatibility with callers.
    :param device: Device for inference; defaults to ``x``'s device.
    :return: Fraction of correctly classified examples as a Python float.
    """
    if device is None:
        device = x.device
    n_examples = x.shape[0]
    # Original crashed on empty input: the accumulator stayed a plain float
    # (no .item()) and the division was by zero.
    if n_examples == 0:
        return 0.0
    correct = 0
    n_batches = math.ceil(n_examples / batch_size)
    with torch.no_grad():
        for counter in range(n_batches):
            x_curr = x[counter * batch_size:(counter + 1) *
                       batch_size].to(device)
            y_curr = y[counter * batch_size:(counter + 1) *
                       batch_size].to(device)

            output = model(x_curr)
            # Accumulate as a Python int to avoid device-tensor arithmetic.
            correct += (output.max(1)[1] == y_curr).sum().item()

    return correct / n_examples
def list_available_models(
        dataset: Union[str, BenchmarkDataset] = BenchmarkDataset.cifar_10,
        threat_model: Union[str, ThreatModel] = ThreatModel.Linf,
        norm: Optional[str] = None):
    """Print a markdown leaderboard table for the given dataset/threat model.

    Reads the per-model JSON files under ``./model_info`` and prints one table
    row per model, sorted by robust accuracy (descending).

    :param dataset: The dataset on which the models are trained.
    :param threat_model: The threat model for which the models are trained.
    :param norm: Deprecated alias for ``threat_model``; overrides it if given.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    if norm is None:
        threat_model_: ThreatModel = ThreatModel(threat_model)
    else:
        threat_model_ = ThreatModel(norm)
        warnings.warn(
            "`norm` has been deprecated and will be removed in a future version.",
            DeprecationWarning)
    models = all_models[dataset_][threat_model_].keys()

    acc_field = ACC_FIELDS[threat_model_]

    json_dicts = []
    jsons_dir = Path("./model_info") / dataset_.value / threat_model_.value

    for model_name in models:
        json_path = jsons_dir / f"{model_name}.json"

        # Some models might not yet be in model_info
        if not json_path.exists():
            continue

        with open(json_path, 'r') as model_info:
            json_dict = json.load(model_info)

        json_dict['model_name'] = model_name
        json_dict['venue'] = 'Unpublished' if json_dict[
            'venue'] == '' else json_dict['venue']
        # Accuracies are stored as percentages; normalize to fractions.
        json_dict[acc_field] = float(json_dict[acc_field]) / 100
        json_dict['clean_acc'] = float(json_dict['clean_acc']) / 100
        json_dicts.append(json_dict)

    json_dicts = sorted(json_dicts, key=lambda d: -d[acc_field])
    print('| # | Model ID | Paper | Clean accuracy | Robust accuracy | Architecture | Venue |')
    print('|:---:|---|---|:---:|:---:|:---:|:---:|')
    for i, json_dict in enumerate(json_dicts):
        if json_dict['model_name'] == 'Chen2020Adversarial':
            # NOTE: this suffix previously contained a raw line break inside
            # the quotes, which was a SyntaxError; reconstructed as '\n'.
            json_dict['architecture'] = json_dict[
                'architecture'] + '\n(3x ensemble)'
        if json_dict['model_name'] != 'Natural':
            print(
                '| **{}** | **{}** | *[{}]({})* | {:.2%} | {:.2%} | {} | {} |'
                .format(i + 1, json_dict['model_name'], json_dict['name'],
                        json_dict['link'], json_dict['clean_acc'],
                        json_dict[acc_field], json_dict['architecture'],
                        json_dict['venue']))
        else:
            print(
                '| **{}** | **{}** | *{}* | {:.2%} | {:.2%} | {} | {} |'
                .format(i + 1, json_dict['model_name'], json_dict['name'],
                        json_dict['clean_acc'], json_dict[acc_field],
                        json_dict['architecture'], json_dict['venue']))
def _get_bibtex_entry(model_name: str, title: str, authors: str, venue: str, year: int):
authors = authors.replace(', ', ' and ')
return (f"@article{{{model_name},\n"
f"\ttitle\t= {{{title}}},\n"
f"\tauthor\t= {{{authors}}},\n"
f"\tjournal\t= {{{venue}}},\n"
f"\tyear\t= {{{year}}}\n"
"}\n")
def get_leaderboard_bibtex(dataset: Union[str, BenchmarkDataset], threat_model: Union[str, ThreatModel]):
    """Collect, print, and return the BibTeX entries for every model-info
    JSON file of the given dataset/threat-model leaderboard.

    :return: A tuple ``(set_of_entries, concatenated_entries_string)``.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)

    jsons_dir = Path("./model_info") / dataset_.value / threat_model_.value

    bibtex_entries = set()

    for json_path in jsons_dir.glob("*.json"):
        # The zoo id is the stem up to the first underscore.
        model_name = json_path.stem.split("_")[0]

        with open(json_path, 'r') as model_info:
            model_dict = json.load(model_info)

        title = model_dict["name"]
        authors = model_dict["authors"]
        full_venue = model_dict["venue"]
        if full_venue == "N/A":
            # No publication venue recorded — nothing to cite.
            continue
        # First token of the venue string, minus any trailing comma.
        venue = full_venue.split(" ")[0].split(",")[0]
        # By convention the venue string ends with the publication year.
        year = model_dict["venue"].split(" ")[-1]

        bibtex_entries.add(
            _get_bibtex_entry(model_name, title, authors, venue, year))

    str_entries = ''
    for entry in bibtex_entries:
        print(entry)
        str_entries += entry

    return bibtex_entries, str_entries
def get_leaderboard_latex(dataset: Union[str, BenchmarkDataset],
                          threat_model: Union[str, ThreatModel],
                          l_keys=('clean_acc', 'external',  # 'autoattack_acc',
                                  'additional_data', 'architecture', 'venue',
                                  'modelzoo_id'),
                          sort_by='external'  # 'autoattack_acc'
                          ):
    """Build the LaTeX table body for the given leaderboard.

    :param dataset: The dataset on which the models are trained.
    :param threat_model: The threat model for which the models are trained.
    :param l_keys: Model-info fields to emit as table columns, in order.
        (Now a tuple — the original used a mutable default list.)
    :param sort_by: Field used to sort rows in descending order.
    :return: The table rows as one string, one ``\\\\``-terminated row each.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)
    models = all_models[dataset_][threat_model_]
    print(models.keys())

    jsons_dir = Path("./model_info") / dataset_.value / threat_model_.value
    entries = []

    for json_path in jsons_dir.glob("*.json"):
        # 'Standard*' stems are zoo ids themselves; other stems encode the id
        # before the first underscore.
        if not json_path.stem.startswith('Standard'):
            model_name = json_path.stem.split("_")[0]
        else:
            model_name = json_path.stem

        with open(json_path, 'r') as model_info:
            model_dict = json.load(model_info)

        str_curr = '\\citet{{{}}}'.format(model_name) if not model_name in ['Standard', 'Standard_R50'] \
            else model_name.replace('_', '\\_')

        for k in l_keys:
            if k == 'external' and not 'external' in model_dict.keys():
                # Fall back to the AutoAttack accuracy when no external
                # evaluation is recorded.
                model_dict[k] = model_dict['autoattack_acc']
            if k == 'additional_data':
                v = 'Y' if model_dict[k] else 'N'
            elif k == 'architecture':
                v = model_dict[k].replace('WideResNet', 'WRN')
                v = v.replace('ResNet', 'RN')
            elif k == 'modelzoo_id':
                v = json_path.stem.split('.json')[0]
                if not v in models.keys():
                    v = 'N/A'
                else:
                    v = v.replace('_', '\\_')
            else:
                v = model_dict[k]
            str_curr += ' & {}'.format(v)
        str_curr += '\\\\'
        entries.append((str_curr, float(model_dict[sort_by])))

    entries = sorted(entries, key=lambda k: k[1], reverse=True)
    entries = ['{} &'.format(i + 1) + a for i, (a, b) in enumerate(entries)]
    # NOTE: the replace() argument previously contained a raw line break
    # inside the quotes (a SyntaxError); reconstructed as '\n'. The net
    # effect is rows joined by spaces on one line.
    entries = '\n'.join(entries).replace('\n', ' ')

    return entries
def update_json(dataset: BenchmarkDataset, threat_model: ThreatModel,
                model_name: str, accuracy: float, adv_accuracy: float,
                eps: Optional[float]) -> None:
    """Write (overwriting) the model-info JSON file for ``model_name``.

    :param dataset: Benchmark dataset the evaluation ran on.
    :param threat_model: Threat model of the evaluation; selects which
        accuracy field the robust accuracy is stored under.
    :param model_name: Zoo id; determines the output filename.
    :param accuracy: Clean accuracy.
    :param adv_accuracy: Robust accuracy.
    :param eps: Perturbation budget used (``None`` e.g. for corruptions).
    """
    json_path = Path(
        "model_info"
    ) / dataset.value / threat_model.value / f"{model_name}.json"
    # exist_ok=True already tolerates a pre-existing directory, so the
    # separate exists() check the original did was redundant (and racy).
    json_path.parent.mkdir(parents=True, exist_ok=True)

    acc_field = ACC_FIELDS[threat_model]

    acc_field_kwarg = {acc_field: adv_accuracy}
    model_info = ModelInfo(dataset=dataset.value, eps=eps, clean_acc=accuracy,
                           **acc_field_kwarg)

    with open(json_path, "w") as f:
        f.write(json.dumps(dataclasses.asdict(model_info), indent=2))
@dataclasses.dataclass
class ModelInfo:
    """Metadata record serialized to each per-model JSON file in model_info.

    All fields are optional; missing values serialize as ``null``.
    """
    # URL of the paper / project page.
    link: Optional[str] = None
    # Paper title.
    name: Optional[str] = None
    # Comma-separated author list.
    authors: Optional[str] = None
    # Whether the model was trained with extra data beyond the benchmark set.
    additional_data: Optional[bool] = None
    # Forward passes per prediction (e.g. >1 for randomized defenses) —
    # presumably; confirm against evaluation code.
    number_forward_passes: Optional[int] = None
    # Dataset identifier string (BenchmarkDataset.value).
    dataset: Optional[str] = None
    # Publication venue; '' is rendered as 'Unpublished' by the leaderboard.
    venue: Optional[str] = None
    # Architecture name, e.g. 'WideResNet-28-10'.
    architecture: Optional[str] = None
    # Perturbation budget of the evaluation.
    eps: Optional[float] = None
    # Clean (unperturbed) accuracy.
    clean_acc: Optional[float] = None
    # Robust accuracy reported by the paper's authors.
    reported: Optional[float] = None
    # Mean corruption accuracy (corruptions threat model).
    corruptions_acc: Optional[str] = None
    # AutoAttack robust accuracy (L2/Linf threat models).
    autoattack_acc: Optional[str] = None
    # Free-form leaderboard footnote.
    footnote: Optional[str] = None
def parse_args(argv=None):
    """Parse command-line arguments for an evaluation run.

    :param argv: Optional explicit argument list; defaults to ``sys.argv[1:]``
        (backward-compatible addition for testability).
    :return: The parsed ``argparse.Namespace``.
    """
    def str2bool(value):
        # argparse's `type=bool` is a trap: bool('False') is True, so any
        # non-empty string parsed as True. Parse common spellings explicitly.
        if isinstance(value, bool):
            return value
        if value.lower() in ('yes', 'true', 't', '1'):
            return True
        if value.lower() in ('no', 'false', 'f', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'invalid boolean value: {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',
                        type=str,
                        default='Carmon2019Unlabeled')
    parser.add_argument('--threat_model',
                        type=str,
                        default='Linf',
                        choices=[x.value for x in ThreatModel])
    parser.add_argument('--dataset',
                        type=str,
                        default='cifar10',
                        choices=[x.value for x in BenchmarkDataset])
    parser.add_argument('--eps', type=float, default=8 / 255)
    parser.add_argument('--n_ex',
                        type=int,
                        default=100,
                        help='number of examples to evaluate on')
    parser.add_argument('--batch_size',
                        type=int,
                        default=500,
                        help='batch size for evaluation')
    parser.add_argument('--data_dir',
                        type=str,
                        default='./data',
                        help='where to store downloaded datasets')
    parser.add_argument('--model_dir',
                        type=str,
                        default='./models',
                        help='where to store downloaded models')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='device to use for computations')
    parser.add_argument('--to_disk', type=str2bool, default=True)
    args = parser.parse_args(argv)
    return args