import os
from typing import Any

import torch
import ignite.distributed as idist
import torchvision
import torchvision.transforms as T
from torch.utils import data as torch_data
from torch.utils.data import Subset
from tqdm import tqdm

from modelguidedattacks.data import get_dataset
from modelguidedattacks.cls_models.accuracy import (
    get_correct_subset_for_models,
    DATASET_METADATA_DIR,
)

from .classification_wrapper import TopKClassificationWrapper

def get_gt_labels(dataset: torch_data.Dataset, train: bool, dataset_name: str):
    """Return the ground-truth label for every sample, caching the result on disk."""
    training_str = "train" if train else "val"
    save_name = os.path.join(DATASET_METADATA_DIR, f"{dataset_name}_labels_{training_str}.p")

    if os.path.exists(save_name):
        print("Found labels cache")
        return torch.load(save_name)

    # Iterate once over the dataset to collect labels, then cache them.
    dataloader = torch_data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)
    gt_labels = []
    for batch in tqdm(dataloader):
        gt_labels.extend(batch[1].tolist())

    gt_labels = torch.tensor(gt_labels)
    torch.save(gt_labels, save_name)

    return gt_labels

def class_balanced_sampling(dataset, gt_labels: torch.Tensor,
                            correct_indices: list, total_samples=1000):
    """Sample up to `total_samples` correctly classified examples, spread as
    evenly as possible across classes.

    correct_indices: dataset indices that the compared models classify correctly
    """
    num_classes = len(dataset.classes)

    correct_indices = torch.tensor(correct_indices)
    correct_mask = torch.zeros((len(dataset),), dtype=torch.bool)
    correct_mask[correct_indices] = True

    total_sampled_indices = 0
    sampled_indices = [[] for _ in range(num_classes)]
    shuffled_inds = torch.randperm(len(dataset))

    for sample_i in shuffled_inds:
        if not correct_mask[sample_i]:
            continue

        sample_class = gt_labels[sample_i].item()

        # Allow each class at most one more sample than the current per-class
        # average, so classes fill up in lockstep.
        desired_samples_in_class = (total_sampled_indices // num_classes) + 1

        if len(sampled_indices[sample_class]) < desired_samples_in_class:
            sampled_indices[sample_class].append(sample_i.item())
            total_sampled_indices += 1

        if total_sampled_indices >= total_samples:
            break

    flattened_indices = []
    for class_samples in sampled_indices:
        flattened_indices.extend(class_samples)

    return torch.tensor(flattened_indices)
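
# Usage sketch (hypothetical helper, not part of the original module): the
# indices returned above plug straight into torch.utils.data.Subset.
def _demo_balanced_subset(dataset, gt_labels: torch.Tensor, correct_indices: list):
    """Hypothetical example of building a class-balanced Subset."""
    balanced_inds = class_balanced_sampling(dataset, gt_labels, correct_indices)
    return Subset(dataset, balanced_inds.tolist())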

def sample_attack_labels(dataset, gt_labels, k, sampler):
    """Draw k distinct attack (target) classes per sample, never the GT class.

    dataset: dataset we're generating attack labels for
    gt_labels: tensor of ground-truth class indices, one per sample
    k: attack size (number of target classes per sample)
    sampler: sampling strategy; only "random" is supported
    """
    # Sample i.i.d. uniforms and argsort each row to simulate a batched randperm.
    attack_label_uniforms = torch.rand((len(gt_labels), len(dataset.classes)))

    # Push the GT class to the bottom of every row so it is never selected.
    batch_inds = torch.arange(len(gt_labels))
    attack_label_uniforms[batch_inds, gt_labels] = -1.

    attack_labels = attack_label_uniforms.argsort(dim=-1, descending=True)[:, :k]

    return attack_labels
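
# Sketch of the batched-randperm trick used above (illustrative only, not part
# of the original module): argsorting a row of i.i.d. uniforms yields a
# uniformly random permutation, so argsorting a (B, C) uniform tensor row-wise
# draws B independent permutations of range(C) without a Python loop.
def _demo_batched_randperm(batch: int = 4, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical demo: one independent permutation of range(num_classes) per row."""
    return torch.rand((batch, num_classes)).argsort(dim=-1)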

def setup_data(config: Any, rank):
    """Download datasets and create dataloaders.

    Parameters
    ----------
    config: needs to contain `dataset`, `k`, `attack_sampling`, `compare_models`,
        `overfit`, `train_batch_size`, `eval_batch_size`, and `num_workers`
    """
    dataset_train, dataset_eval = get_dataset(config.dataset)

    train_subset = None
    val_subset = None
    attack_labels_train = None
    attack_labels_val = None

    # Sampling is done once on rank 0 and broadcast so every rank sees
    # identical subsets and attack labels.
    if rank == 0:
        gt_labels_train = get_gt_labels(dataset_train, True, config.dataset)
        gt_labels_val = get_gt_labels(dataset_eval, False, config.dataset)

        attack_labels_train = sample_attack_labels(dataset_train, gt_labels_train, k=config.k,
                                                   sampler=config.attack_sampling)
        attack_labels_val = sample_attack_labels(dataset_eval, gt_labels_val, k=config.k,
                                                 sampler=config.attack_sampling)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Restrict attacks to samples every compared model classifies correctly.
        correct_train_set = get_correct_subset_for_models(config.compare_models,
                                                          config.dataset, device,
                                                          train=True)
        correct_eval_set = get_correct_subset_for_models(config.compare_models,
                                                         config.dataset, device,
                                                         train=False)

        # Balanced sampling
        train_subset = class_balanced_sampling(dataset_train, gt_labels_train,
                                               correct_train_set)
        val_subset = class_balanced_sampling(dataset_eval, gt_labels_val,
                                             correct_eval_set)

        if config.overfit:
            # Shrink both splits to the same 16 random positions so the model
            # can memorize a tiny, fixed set of samples.
            rand_inds = torch.randperm(len(val_subset))[:16]
            train_subset = train_subset[rand_inds]
            val_subset = val_subset[rand_inds]

    train_subset = idist.broadcast(train_subset, safe_mode=True)
    val_subset = idist.broadcast(val_subset, safe_mode=True)
    attack_labels_train = idist.broadcast(attack_labels_train, safe_mode=True)
    attack_labels_val = idist.broadcast(attack_labels_val, safe_mode=True)

    dataset_train = TopKClassificationWrapper(dataset_train, k=config.k,
                                              attack_labels=attack_labels_train)
    dataset_eval = TopKClassificationWrapper(dataset_eval, k=config.k,
                                             attack_labels=attack_labels_val)

    dataset_train = Subset(dataset_train, train_subset)
    dataset_eval = Subset(dataset_eval, val_subset)
    dataloader_train = idist.auto_dataloader(
        dataset_train,
        batch_size=config.train_batch_size,
        shuffle=not config.overfit,
        num_workers=config.num_workers,
    )

    dataloader_eval = idist.auto_dataloader(
        dataset_eval,
        batch_size=config.eval_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
    )

    return dataloader_train, dataloader_eval
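

if __name__ == "__main__":
    # Minimal usage sketch. The field names mirror the attributes read by
    # setup_data above; the values are placeholders, not the project's defaults.
    from types import SimpleNamespace

    config = SimpleNamespace(
        dataset="cifar10",        # assumed: any name accepted by get_dataset
        k=5,
        attack_sampling="random",
        compare_models=[],        # assumed: model specs for get_correct_subset_for_models
        overfit=False,
        train_batch_size=64,
        eval_batch_size=64,
        num_workers=4,
    )

    train_loader, eval_loader = setup_data(config, rank=idist.get_rank())
    print(f"train batches: {len(train_loader)}, eval batches: {len(eval_loader)}")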