Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import optim | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader | |
| import logging | |
| import numpy as np | |
| from tqdm import tqdm | |
| from sklearn.cluster import KMeans | |
| import os | |
| from utils.toolkit import tensor2numpy, accuracy_domain | |
| from models.slinet import SliNet | |
| from utils.lr_scheduler import build_lr_scheduler | |
| from utils.data_manager import DataManager | |
| from eval import compute_predictions | |
| import wandb | |
| class Prompt2Guard: | |
def __init__(self, args: dict):
    """Create the Prompt2Guard trainer from a configuration dict.

    Args:
        args: run configuration. The whole dict is handed to ``SliNet`` and
            kept on ``self.args``. Keys read directly here: ``device``,
            ``EPSILON``, ``init_lr``, ``init_lr_decay``, ``init_weight_decay``,
            ``epochs``, ``warmup_epoch``, ``lrate``, ``lrate_decay``,
            ``batch_size``, ``batch_size_eval``, ``weight_decay``,
            ``label_smoothing``, ``enable_prev_prompt``, ``num_workers`` and
            ``filename``.
    """
    # --- model ---------------------------------------------------------------
    self.network = SliNet(args)
    self.device = args["device"]
    self.class_num = self.network.class_num

    # --- task / class bookkeeping --------------------------------------------
    self.cur_task = -1       # incremented at the start of every incremental_train
    self.n_clusters = 5      # k-means centres per task for ``all_keys``
    self.n_cluster_one = 1   # single-centre variants of the prototypes
    self.known_classes = 0
    self.total_classes = 0

    # --- prototype keys: clustering() appends one entry per task -------------
    self.all_keys = []              # n_clusters centres over all images
    self.all_keys_one_vector = []   # one centre over all images
    self.real_keys_one_vector = []  # one centre over real images only
    self.fake_keys_one_vector = []  # one centre over fake images only

    # --- optimisation hyper-parameters ---------------------------------------
    cfg = args
    self.EPSILON = cfg["EPSILON"]
    self.init_lr = cfg["init_lr"]
    self.init_lr_decay = cfg["init_lr_decay"]
    self.init_weight_decay = cfg["init_weight_decay"]
    self.epochs = cfg["epochs"]
    self.warmup_epoch = cfg["warmup_epoch"]
    self.lrate = cfg["lrate"]
    self.lrate_decay = cfg["lrate_decay"]
    self.batch_size = cfg["batch_size"]
    self.batch_size_eval = cfg["batch_size_eval"]
    self.weight_decay = cfg["weight_decay"]
    self.label_smoothing = cfg["label_smoothing"]
    self.enable_prev_prompt = cfg["enable_prev_prompt"]

    # --- system ---------------------------------------------------------------
    # When SLURM_CPUS_ON_NODE is set it overrides the configured worker count.
    configured_workers = cfg["num_workers"]
    self.num_workers = int(os.environ.get("SLURM_CPUS_ON_NODE", configured_workers))
    self.filename = cfg["filename"]

    self.args = args
    # NOTE: Weights & Biases logging is not initialised here (its setup was
    # left commented out), so no ``wandb_logger`` attribute is created.
def after_task(self, nb_tasks):
    """Finish the current task.

    Marks all classes seen so far as known. When ``enable_prev_prompt`` is
    set and ``network.numtask < nb_tasks``, copies the state of
    ``prompt_learner[numtask - 1]`` into ``prompt_learner[numtask]``
    (presumably seeding the prompt for the upcoming task — confirm against
    SliNet).

    Args:
        nb_tasks: total number of tasks in the run.
    """
    self.known_classes = self.total_classes
    if not self.enable_prev_prompt:
        return
    next_idx = self.network.numtask
    if next_idx >= nb_tasks:
        return
    learners = self.network.prompt_learner
    with torch.no_grad():
        learners[next_idx].load_state_dict(learners[next_idx - 1].state_dict())
def incremental_train(self, data_manager: DataManager):
    """Train on the next task supplied by ``data_manager``.

    Advances ``cur_task`` and extends ``total_classes`` by the new task's size.
    Grows the network via ``update_fc()`` and builds the loaders:

    * train: only the new task's classes, shuffled;
    * test: every class seen so far.

    Then trains the active prompt learner and computes this task's prototype
    keys with ``clustering``.

    The best-accuracy checkpoint is written during training, before this
    task's prototypes exist. So if ``./checkpoint/<run_name>/weights/best.pt``
    was written for the current task, its ``"keys"`` entry is rewritten from
    the trainer's prototype lists after clustering.

    Args:
        data_manager: provides ``get_task_size`` and ``get_dataset``.
    """
    self.cur_task += 1
    self.total_classes = self.known_classes + data_manager.get_task_size(self.cur_task)
    self.network.update_fc()
    logging.info("Learning on {}-{}".format(self.known_classes, self.total_classes))

    # Training data: only the classes introduced by this task.
    train_dataset = data_manager.get_dataset(
        np.arange(self.known_classes, self.total_classes),
        source="train",
        mode="train",
    )
    self.train_loader = DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
    )
    # Evaluation data: every class seen so far.
    test_dataset = data_manager.get_dataset(
        np.arange(0, self.total_classes), source="test", mode="test"
    )
    self.test_loader = DataLoader(
        test_dataset,
        batch_size=self.batch_size_eval,
        shuffle=False,
        num_workers=self.num_workers,
    )

    self._train(self.train_loader, self.test_loader)
    self.clustering(self.train_loader)

    # Bring the saved checkpoint's prototype keys up to date with clustering().
    ckpt_path = os.path.join(f'./checkpoint/{self.args["run_name"]}/weights/', "best.pt")
    if not os.path.isfile(ckpt_path):
        return
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(checkpoint, dict) or checkpoint.get("tasks") != self.cur_task:
        # The file was not written for this task (e.g. no epoch improved),
        # so leave it untouched rather than mixing tasks.
        return

    def stacked(keys):
        # Stack the per-task prototypes, drop singleton dims and store them
        # as float16 on CPU; no prototypes yet -> empty tensor.
        if not keys:
            return torch.empty(0, dtype=torch.float16)
        return torch.stack([k.detach().cpu() for k in keys]).squeeze().to(torch.float16)

    checkpoint["keys"] = {
        "all_keys": stacked(self.all_keys),
        "all_keys_one_cluster": stacked(self.all_keys_one_vector),
        "real_keys_one_cluster": stacked(self.real_keys_one_vector),
        "fake_keys_one_cluster": stacked(self.fake_keys_one_vector),
    }
    torch.save(checkpoint, ckpt_path)
    logging.info(f"Updated prototype keys in {ckpt_path}")
def _train(self, train_loader, test_loader):
    """Freeze everything except the active prompt learner, then train it.

    The active learner is ``prompt_learner.<network.numtask - 1>``. The first
    task (``cur_task == 0``) uses ``init_lr`` / ``init_weight_decay``; later
    tasks use ``lrate`` / ``weight_decay``. Both use SGD with momentum 0.9 and
    a ``build_lr_scheduler`` cosine schedule with a constant warm-up
    (``warmup_epoch``, ``warmup_cons_lr=1e-5``). Training runs for
    ``self.epochs`` epochs through ``train_function``.

    Raises:
        RuntimeError: if no parameter belongs to the active prompt learner.
    """
    self.network.to(self.device)

    # Match "prompt_learner.<i>." (or a name ending in "prompt_learner.<i>")
    # instead of a bare substring. Otherwise task index 1 would also unfreeze
    # prompt_learner.10, prompt_learner.11, ...
    active = f"prompt_learner.{self.network.numtask - 1}"
    trainable = {}
    for name, param in self.network.named_parameters():
        selected = f"{active}." in name or name.endswith(active)
        param.requires_grad_(selected)
        if selected:
            trainable[name] = param
    logging.info(f"Parameters to be updated: {set(trainable)}")
    if not trainable:
        raise RuntimeError(f"No parameters found for '{active}'; nothing to train.")

    if self.cur_task == 0:
        lr, weight_decay = self.init_lr, self.init_weight_decay
    else:
        lr, weight_decay = self.lrate, self.weight_decay

    # Hand SGD only the trainable tensors. Frozen parameters (e.g. an earlier
    # task's prompt) can still hold a .grad from previous training. SGD updates
    # every parameter whose grad is not None. With zero_grad(set_to_none=False),
    # the default before PyTorch 2.0, that grad becomes a zero tensor, and
    # weight decay would still move the "frozen" weights.
    optimizer = optim.SGD(
        list(trainable.values()),
        momentum=0.9,
        lr=lr,
        weight_decay=weight_decay,
    )
    scheduler = build_lr_scheduler(
        optimizer,
        lr_scheduler="cosine",
        warmup_epoch=self.warmup_epoch,
        warmup_type="constant",
        warmup_cons_lr=1e-5,
        max_epoch=self.epochs,
    )
    self.run_epoch = self.epochs
    self.train_function(train_loader, test_loader, optimizer, scheduler)
def train_function(self, train_loader, test_loader, optimizer, scheduler):
    """Run the epoch loop for the current task and restore the best weights.

    Each epoch trains on ``train_loader`` and then scores the network on
    ``test_loader`` with ``_compute_accuracy_domain``. Whenever the test
    accuracy beats the best seen in this call, a checkpoint is written to
    ``./checkpoint/<args["run_name"]>/weights/best.pt`` containing:

    * ``model_state_dict``: every parameter with ``requires_grad=True``;
    * ``keys``: the prototype keys currently held by the trainer
      (``all_keys``, ``all_keys_one_vector``, ``real_keys_one_vector``,
      ``fake_keys_one_vector``), stacked and cast to float16;
    * ``tasks``, ``K``, ``topk_classes``, ``ensembling_flags``, ``accuracy``.

    When the loop ends, the saved parameters are loaded back into
    ``self.network`` non-strictly, because only trainable parameters are
    stored. If no checkpoint was written during this call, the final-epoch
    weights are kept.

    Args:
        train_loader: iterable of ``(object_name, inputs, targets)`` batches.
        test_loader: loader forwarded to ``_compute_accuracy_domain``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
    """
    # Inference-time settings stored with the weights (values kept from the
    # previous version of this method).
    inference_k = 7
    topk_classes = 5
    ensembling_flags = [False, False, True, False]

    save_dir = f'./checkpoint/{self.args["run_name"]}/weights/'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "best.pt")

    def stacked_keys(keys):
        # Stack the per-task prototypes, drop singleton dims and store them
        # as float16 on CPU. An empty list (no task clustered yet) gives an
        # empty tensor.
        if not keys:
            return torch.empty(0, dtype=torch.float16)
        return torch.stack([k.detach().cpu() for k in keys]).squeeze().to(torch.float16)

    best_acc = 0.0
    checkpoint_written = False
    n_batches = len(train_loader) or 1
    prog_bar = tqdm(range(self.run_epoch))
    for epoch in prog_bar:
        self.network.train()
        losses = 0.0
        correct, total = 0, 0
        with tqdm(train_loader, unit="batch", mininterval=10) as tepoch:
            tepoch.set_description(f"Epoch {epoch}", refresh=False)
            # Iterate the tqdm wrapper (not the raw loader) so the bar advances.
            for object_name, inputs, targets in tepoch:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                # Keep current-task samples and shift labels to start at 0.
                # NOTE: object_name is not filtered by this mask. That is safe
                # because incremental_train builds the train set from
                # current-task classes only, so the mask keeps every sample.
                mask = (targets >= self.known_classes).nonzero().view(-1)
                inputs = torch.index_select(inputs, 0, mask)
                targets = torch.index_select(targets, 0, mask) - self.known_classes

                logits = self.network(inputs, object_name)["logits"]
                loss = F.cross_entropy(logits, targets, label_smoothing=self.label_smoothing)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                losses += batch_loss
                tepoch.set_postfix(loss=batch_loss)
                _, preds = torch.max(logits, dim=1)
                correct += (preds == targets).sum().item()
                total += len(targets)
        scheduler.step()
        train_acc = np.around(correct * 100 / total, decimals=2) if total else 0.0

        self.network.eval()
        test_acc = self._compute_accuracy_domain(self.network, test_loader, epoch)

        if test_acc > best_acc:
            best_acc = test_acc
            logging.info(f"New best accuracy: {best_acc}. Saving checkpoint to {save_path}")
            save_dict = {
                "tasks": self.cur_task,
                "model_state_dict": {
                    name: param.detach().cpu()
                    for name, param in self.network.named_parameters()
                    if param.requires_grad
                },
                "keys": {
                    "all_keys": stacked_keys(self.all_keys),
                    "all_keys_one_cluster": stacked_keys(self.all_keys_one_vector),
                    "real_keys_one_cluster": stacked_keys(self.real_keys_one_vector),
                    "fake_keys_one_cluster": stacked_keys(self.fake_keys_one_vector),
                },
                "K": inference_k,
                "topk_classes": topk_classes,
                "ensembling_flags": ensembling_flags,
                # Store a plain float in case the accuracy is a numpy scalar.
                # Numpy scalars are not on torch.load's default weights_only
                # allow-list.
                "accuracy": float(best_acc),
            }
            torch.save(save_dict, save_path)
            checkpoint_written = True

        info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f} (Best {:.2f})".format(
            self.cur_task,
            epoch + 1,
            self.run_epoch,
            losses / n_batches,
            train_acc,
            test_acc,
            best_acc,
        )
        prog_bar.set_description(info)

    logging.info(f"Task {self.cur_task} finished. Best test accuracy: {best_acc}")
    if not checkpoint_written:
        # Nothing was saved in this call. Any best.pt on disk belongs to an
        # earlier task (or does not exist), so keep the current weights.
        logging.warning("No checkpoint was saved for this task; keeping the final-epoch weights.")
        return

    logging.info(f"Loading best weights from {save_path}")
    checkpoint = torch.load(save_path, map_location="cpu")
    self.network.to(self.device)
    self.network.load_state_dict(checkpoint["model_state_dict"], strict=False)
def clustering(self, dataloader):
    """Build image-feature prototypes with k-means and append them to the key lists.

    Image features (``self.network.extract_vector``) are gathered for every batch
    of ``dataloader``. Samples labelled ``self.known_classes`` are treated as real
    and samples labelled ``self.known_classes + 1`` as fake. k-means
    (``random_state=0``, ``n_init="auto"``) centroids are then appended to:

    * ``self.all_keys``: ``self.n_clusters`` centroids over all features;
    * ``self.all_keys_one_vector``: ``self.n_cluster_one`` centroid(s) over all features;
    * ``self.real_keys_one_vector``: ``self.n_cluster_one`` centroid(s) over real features;
    * ``self.fake_keys_one_vector``: ``self.n_cluster_one`` centroid(s) over fake features.

    Each appended entry is a tensor of centroids placed on ``self.device``.

    Args:
        dataloader: iterable yielding ``(_, inputs, targets)`` batches.
    """

    def _fit_centroids(k, features):
        # Fixed seed so repeated runs give the same prototypes.
        model = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(features)
        return torch.tensor(model.cluster_centers_).to(self.device)

    def _as_numpy(chunks):
        # Concatenate per-batch feature tensors and hand them to sklearn as numpy.
        return torch.cat(chunks, 0).detach().cpu().numpy()

    every_batch, real_batches, fake_batches = [], [], []
    for _, inputs, targets in dataloader:
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        real_idx = (targets == self.known_classes).nonzero().view(-1)  # offset 0: real
        fake_idx = (targets == self.known_classes + 1).nonzero().view(-1)  # offset 1: fake
        with torch.no_grad():
            feats = self.network.extract_vector(inputs)  # image features only
        every_batch.append(feats)
        real_batches.append(feats.index_select(0, real_idx))
        fake_batches.append(feats.index_select(0, fake_idx))

    every_np = _as_numpy(every_batch)
    real_np = _as_numpy(real_batches)
    fake_np = _as_numpy(fake_batches)

    self.all_keys.append(_fit_centroids(self.n_clusters, every_np))
    self.all_keys_one_vector.append(_fit_centroids(self.n_cluster_one, every_np))
    self.real_keys_one_vector.append(_fit_centroids(self.n_cluster_one, real_np))
    self.fake_keys_one_vector.append(_fit_centroids(self.n_cluster_one, fake_np))
def _compute_accuracy_domain(self, model, loader, epoch):
    """Return the domain accuracy of ``model`` on ``loader`` as a percentage.

    A sample counts as correct when ``argmax(logits) % self.class_num`` equals
    ``target % self.class_num``.

    Args:
        model: module exposing ``eval()`` and callable as
            ``model(inputs, object_labels)``, returning a dict with a
            ``"logits"`` entry.
        loader: iterable of ``(object_labels, inputs, targets)`` batches.
        epoch: used only in the progress-bar description.

    Returns:
        Accuracy in percent, rounded with ``np.around(..., decimals=2)``.
        Returns ``0.0`` (with a warning) when the loader yields no samples,
        instead of crashing on a division by zero.
    """
    model.eval()
    correct, total = 0, 0
    with tqdm(loader, unit='batch', mininterval=10) as tepoch:
        tepoch.set_description(f'Validation Epoch {epoch}', refresh=False)
        # Iterate the tqdm wrapper, not the raw loader; otherwise the
        # progress bar never advances.
        for object_labels, inputs, targets in tepoch:
            inputs = inputs.to(self.device)
            with torch.no_grad():
                outputs = model(inputs, object_labels)["logits"]
            predicts = torch.max(outputs, dim=1)[1]
            matches = (predicts % self.class_num).cpu() == (targets.cpu() % self.class_num)
            correct += int(matches.sum())
            total += len(targets)
            tepoch.set_postfix(acc=np.around(correct * 100 / total, decimals=2))
    if total == 0:
        logging.warning("Validation loader yielded no samples; reporting 0.0 accuracy")
        return np.around(0.0, decimals=2)
    return np.around(correct * 100 / total, decimals=2)
def save_checkpoint(self):
    """Save prompt-learner weights, prototype keys and metadata to ``best.pt``.

    The checkpoint is written to
    ``./checkpoint/<self.args["run_name"]>/weights/best.pt``. The directory is
    created if it does not exist yet. The saved dict contains:

    * ``"tasks"``: ``self.cur_task``;
    * ``"model_state_dict"``: network parameters whose name contains
      ``"prompt_learner"``;
    * ``"keys"``: the key lists stacked, squeezed and cast to float16, stored
      as CPU tensors so the file loads without a GPU;
    * ``"K"``, ``"topk_classes"``, ``"ensembling_flags"``: read from the network.

    The network is moved to CPU while saving and moved back to ``self.device``
    afterwards, even if saving fails.
    """
    layers_to_save = ["prompt_learner"]

    def _pack_keys(keys):
        # Same layout as the in-memory keys (stack + squeeze, float16), forced
        # to CPU like the keys saved by the training loop.
        return torch.stack(keys).squeeze().to(dtype=torch.float16).cpu()

    self.network.cpu()
    try:
        model_state_dict = {
            name: param
            for name, param in self.network.named_parameters()
            if any(layer in name for layer in layers_to_save)
        }
        keys_dict = {
            "all_keys": _pack_keys(self.all_keys),
            "all_keys_one_cluster": _pack_keys(self.all_keys_one_vector),
            "real_keys_one_cluster": _pack_keys(self.real_keys_one_vector),
            "fake_keys_one_cluster": _pack_keys(self.fake_keys_one_vector),
        }
        ensembling_flags = [
            self.network.ensemble_token_embedding,
            self.network.ensemble_before_cosine_sim,
            self.network.ensemble_after_cosine_sim,
            self.network.confidence_score_enable,
        ]
        save_dict = {
            "tasks": self.cur_task,
            "model_state_dict": model_state_dict,
            "keys": keys_dict,
            "K": self.network.K,
            "topk_classes": self.network.topk_classes,
            "ensembling_flags": ensembling_flags,
        }
        save_path = f'./checkpoint/{self.args["run_name"]}/weights/best.pt'
        # torch.save does not create parent directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(save_dict, save_path)
    finally:
        # Without this the network stays on CPU and later forward passes with
        # inputs on self.device fail.
        self.network.to(self.device)
def eval_task(self):
    """Evaluate the network on ``self.test_loader``.

    Returns:
        A dict mapping every prediction key produced by ``self._eval`` to
        ``accuracy_domain(y_pred[key], y_true, self.known_classes,
        class_num=self.class_num)``.
    """
    predictions, labels = self._eval(self.test_loader)
    return {
        logit_key: accuracy_domain(
            preds, labels, self.known_classes, class_num=self.class_num
        )
        for logit_key, preds in predictions.items()
    }
def prepare_tensor(self, tensor, unsqueeze=False):
    """Stack a list of tensors, squeeze all size-1 dims and cast to float16.

    Args:
        tensor: sequence of tensors accepted by ``torch.stack``. Despite the
            name, this is a list, not a single tensor.
        unsqueeze: if True, prepend a size-1 axis after squeezing.

    Returns:
        The resulting float16 tensor.
    """
    stacked = torch.stack(tensor)
    packed = stacked.squeeze().to(dtype=torch.float16)
    return packed.unsqueeze(0) if unsqueeze else packed
def _eval(self, loader):
    """Run ``self.network.interface`` over ``loader`` and collect predictions.

    The prototype keys are packed with ``self.prepare_tensor`` into the dict
    passed to ``interface``, with ``"prototype"`` set to ``"fake"``.

    Args:
        loader: iterable of ``(object_name, inputs, targets)`` batches.

    Returns:
        ``(y_pred, y_true)``. ``y_pred`` maps each key returned by
        ``compute_predictions`` to the concatenated numpy predictions.
        ``y_true`` is the concatenated numpy targets.
    """
    self.network.eval()
    total_tasks = self.network.numtask
    # With a single task, prepare_tensor's squeeze would drop the leading axis
    # of the one-vector keys; unsqueeze(0) re-adds it. Presumably that axis
    # indexes tasks; TODO confirm.
    # NOTE(review): "all_keys" is never unsqueezed. Confirm this is intended.
    single_task = total_tasks == 1
    key_dict = {
        "all_keys": self.prepare_tensor(self.all_keys),
        "all_keys_one_cluster": self.prepare_tensor(self.all_keys_one_vector, single_task),
        "real_keys_one_cluster": self.prepare_tensor(self.real_keys_one_vector, single_task),
        "fake_keys_one_cluster": self.prepare_tensor(self.fake_keys_one_vector, single_task),
        "upperbound": self.prepare_tensor(self.fake_keys_one_vector, single_task),
        "prototype": "fake",
    }
    apply_softmax = False  # debug switch: softmax the outputs before predicting

    collected = {}
    labels = []
    for object_name, inputs, targets in loader:
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        with torch.no_grad():
            # Original note: outputs are [B, T, P] (unverified here).
            outputs = self.network.interface(inputs, object_name, total_tasks, key_dict)
        if apply_softmax:
            outputs = torch.nn.functional.softmax(outputs, dim=-1)
        for name, batch_pred in compute_predictions(outputs).items():
            collected.setdefault(name, []).append(batch_pred.cpu().numpy())
        labels.append(targets.cpu().numpy())

    y_true = np.concatenate(labels)
    y_pred = {name: np.concatenate(chunks) for name, chunks in collected.items()}
    return y_pred, y_true