Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from modelguidedattacks.data import get_dataset | |
| from . import get_model | |
| from .registry import ClsModel | |
| from typing import Optional, List | |
# Directory (relative to the current working directory) in which the cached
# correct-subset files for each dataset/model/split are stored.
DATASET_METADATA_DIR = "./dataset_metadata"
def correct_subset_cache_path(dataset_name: str, model_name: str, train: bool):
    """
    Build the path of the cache file holding a model's correct subset.

    dataset_name: Name of dataset
    model_name: Name of model
    train: Selects the "train" split suffix when truthy, otherwise "val"

    Returns DATASET_METADATA_DIR joined with
    "{dataset_name}_{model_name}_{train|val}.p"
    """
    if train:
        split_tag = "train"
    else:
        split_tag = "val"

    return os.path.join(
        DATASET_METADATA_DIR,
        f"{dataset_name}_{model_name}_{split_tag}.p",
    )
def get_correct_subset(model: Optional[ClsModel]=None, dataset_name: Optional[str]=None,
                       model_name: Optional[str]=None, train=True, batch_size=256,
                       force_cache=False, device="cuda"):
    """
    Return the indices of the dataset items that a model classifies correctly.

    The result is cached on disk under DATASET_METADATA_DIR, keyed by dataset
    name, model name and split. When a cache file exists it is returned
    directly and the model is never run.

    Pass EITHER `model` OR both `dataset_name` and `model_name`; mixing the
    two forms is rejected.

    model: Model to evaluate. Its `dataset_name`, `model_name` and `device`
        attributes are used.
    dataset_name: Name of dataset (must be omitted if model is provided)
    model_name: Name of model (must be omitted if model is provided)
    train: Use training dataset (otherwise the validation dataset)
    batch_size: Batch size to use while evaluating
    force_cache: Only read from cache and fail if not available
    device: Device the model is loaded on when it is looked up by name;
        unused when `model` is provided.

    Returns a set of indices in dataset of correctly classified items

    Raises:
        ValueError: neither `model` nor the dataset/model names were given.
        AssertionError: `model` was combined with names, or only one of the
            two names was given.
        Exception: `force_cache` is set and no cache file exists.
    """
    if model is not None:
        assert dataset_name is None, "dataset_name must not be given together with model"
        assert model_name is None, "model_name must not be given together with model"
        dataset_name = model.dataset_name
        model_name = model.model_name
    else:
        if dataset_name is None and model_name is None:
            # Previously this fell through to an AttributeError on None.
            raise ValueError("Provide either `model` or both `dataset_name` and `model_name`.")
        assert dataset_name is not None, "dataset_name is required when model is not given"
        assert model_name is not None, "model_name is required when model is not given"

    filename_train_val = "train" if train else "val"
    subset_cache_filename = f"{dataset_name}_{model_name}_{filename_train_val}.p"
    subset_cache_path = os.path.join(DATASET_METADATA_DIR, subset_cache_filename)

    if os.path.exists(subset_cache_path):
        # NOTE(review): torch.load unpickles the file; the cache is assumed to
        # have been written by this function, not taken from an untrusted source.
        return torch.load(subset_cache_path)

    if force_cache:
        raise Exception("Cache not found and requested for cached correct subset.")

    # Create the cache directory before the (potentially long) evaluation so
    # that a filesystem problem surfaces immediately rather than at save time.
    os.makedirs(DATASET_METADATA_DIR, exist_ok=True)

    logging.info(f"No cache found. Computing correct subset for {dataset_name}-{model_name} Train: {train}")

    if model is None:
        model = get_model(dataset_name, model_name, device)

    # Switch to evaluation mode; this is applied to caller-supplied models as
    # well so they are scored in the same mode as models loaded by name.
    model.eval()

    train_dataset, val_dataset = get_dataset(dataset_name)
    dataset = train_dataset if train else val_dataset

    # shuffle=False keeps dataset order, so batch i starts at i * batch_size.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model_device = model.device
    is_cuda = torch.device(model_device).type.startswith("cuda")

    correct_indices = []

    # Only predictions are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch_i, (batch_imgs, batch_gt_class) in tqdm(enumerate(dataloader), total=len(dataloader)):
            if is_cuda:
                # Presumably keeps the per-batch progress timing accurate — TODO confirm.
                torch.cuda.synchronize(model_device)

            data_start_index = batch_i * batch_size

            predictions = model(batch_imgs.to(model_device))  # [B, C]
            prediction_class_idx = predictions.argmax(dim=-1)  # [B] (long)
            prediction_correct = prediction_class_idx == batch_gt_class.to(model_device)

            # Positions of correct items within the batch, shifted to dataset indices.
            batch_correct_idxs = data_start_index + prediction_correct.nonzero()[:, 0]
            correct_indices.extend(batch_correct_idxs.tolist())

    correct_subset = set(correct_indices)
    torch.save(correct_subset, subset_cache_path)

    return correct_subset
def get_correct_subset_for_models(model_names: List[str], dataset_name, device, train):
    """
    Return the dataset indices that every listed model classifies correctly.

    Each model's correct subset is obtained from get_correct_subset (by model
    and dataset name) and the subsets are intersected.

    model_names: Names of the models to intersect over (at least one)
    dataset_name: Name of dataset
    device: Device forwarded to get_correct_subset
    train: Use training dataset (forwarded to get_correct_subset)

    Returns a list of indices present in every model's correct subset

    Raises:
        ValueError: `model_names` is empty.
    """
    correct_intersection = None

    for model_name in model_names:
        model_correct_subset = get_correct_subset(model_name=model_name, dataset_name=dataset_name,
                                                  device=device, train=train)

        if correct_intersection is None:
            correct_intersection = model_correct_subset
        else:
            # intersection() builds a new set, so no per-model subset is mutated.
            correct_intersection = model_correct_subset.intersection(correct_intersection)

    if correct_intersection is None:
        # No model was iterated; previously this surfaced as an opaque
        # TypeError from list(None).
        raise ValueError("model_names must contain at least one model name.")

    return list(correct_intersection)