| """ | |
| This is a hacky little attempt using the tools from the trigger creation script to identify a | |
| good set of label strings. The idea is to train a linear classifier over the predict token and | |
| then look at the most similar tokens. | |
| """ | |
import os.path
import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM  # , LlamaForCausalLM
)
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from . import augments, utils, model_wrapper

logger = logging.getLogger(__name__)
def get_final_embeddings(model):
    """Return the last module before the LM head, whose output feeds the vocabulary projection."""
    if isinstance(model, BertForMaskedLM):
        return model.cls.predictions.transform
    elif isinstance(model, RobertaForMaskedLM):
        return model.lm_head.layer_norm
    elif isinstance(model, GPT2LMHeadModel):
        return model.transformer.ln_f
    elif isinstance(model, GPTNeoForCausalLM):
        return model.transformer.ln_f
    elif isinstance(model, XLNetLMHeadModel):
        return model.transformer.dropout
    elif "opt" in model.name_or_path:
        return model.model.decoder.final_layer_norm
    elif "glm" in model.name_or_path:
        return model.glm.transformer.layers[35]
    elif "llama" in model.name_or_path:
        return model.model.norm
    else:
        raise NotImplementedError(f'{model} not currently supported')
| def get_word_embeddings(model): | |
| if isinstance(model, BertForMaskedLM): | |
| return model.cls.predictions.decoder.weight | |
| elif isinstance(model, RobertaForMaskedLM): | |
| return model.lm_head.decoder.weight | |
| elif isinstance(model, GPT2LMHeadModel): | |
| return model.lm_head.weight | |
| elif isinstance(model, GPTNeoForCausalLM): | |
| return model.lm_head.weight | |
| elif isinstance(model, XLNetLMHeadModel): | |
| return model.lm_loss.weight | |
| elif "opt" in model.name_or_path: | |
| return model.lm_head.weight | |
| elif "glm" in model.name_or_path: | |
| return model.glm.transformer.final_layernorm.weight | |
| elif "llama" in model.name_or_path: | |
| return model.lm_head.weight | |
| else: | |
| raise NotImplementedError(f'{model} not currently supported') | |
def random_prompt(args, tokenizer, device):
    """Draw num_prompt_tokens distinct random vocabulary ids; returns a (1, num_prompt_tokens) tensor."""
    prompt = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    prompt_ids = torch.tensor(prompt, device=device).unsqueeze(0)
    return prompt_ids
def topk_search(args, largest=True):
    """Count which vocabulary ids appear most often in the top-k (or bottom-k) of the wrapper's
    logits under a random prompt, over a sample of training batches (roughly 10,000 examples)."""
    utils.set_seed(args.seed)
    device = args.device

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)

    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    predictor = model_wrapper.ModelWrapper(model, tokenizer)
    mask_cnt = torch.zeros([tokenizer.vocab_size])
    pbar = tqdm(enumerate(train_loader))
    with torch.no_grad():
        count = 0
        for step, model_inputs in pbar:
            count += len(model_inputs["input_ids"])
            prompt_ids = random_prompt(args, tokenizer, device)
            logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            _, top = logits.topk(args.k, largest=largest)
            ids, frequency = torch.unique(top.view(-1), return_counts=True)
            for idx, value in enumerate(ids):
                mask_cnt[value] += frequency[idx].detach().cpu()
            pbar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}")
            if count > 10000:
                break

    top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k)
    tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    key = "topk" if largest else "lastk"
    print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}")

    # Merge into any existing results file; previously the counts were silently dropped
    # when the output file did not exist yet.
    best_results = torch.load(args.output) if os.path.exists(args.output) else {}
    best_results[key] = top_ids
    torch.save(best_results, args.output)
class OutputStorage:
    """
    This object stores the output of a given PyTorch module (captured with a forward hook),
    which otherwise might not be retained.
    """
    def __init__(self, module):
        self._stored_output = None
        module.register_forward_hook(self.hook)

    def hook(self, module, input, output):
        self._stored_output = output

    def get(self):
        return self._stored_output
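# Minimal usage sketch (illustrative only; this is how label_search uses it below):
#
#   storage = OutputStorage(get_final_embeddings(model))
#   model(**model_inputs)
#   hidden_states = storage.get()  # output of the hooked module, e.g. (batch, seq_len, hidden)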
def label_search(args):
    """Fit a linear head over the hidden state at the predict token, then read off the
    vocabulary tokens closest to each class weight as candidate label words."""
    device = args.device
    utils.set_seed(args.seed)

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)

    final_embeddings = get_final_embeddings(model)
    embedding_storage = OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)

    label_map = args.label_map
    reverse_label_map = {y: x for x, y in label_map.items()}

    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype)
    projection.to(device)

    # Obtain the initial trigger tokens and label mapping
    if args.prompt:
        prompt_ids = tokenizer.encode(
            args.prompt,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(prompt_ids) == tokenizer.num_prompt_tokens
    else:
        if "llama" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist()
        elif "gpt" in args.model_name:
            # prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        elif "opt" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        else:
            prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens
    prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
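    # For example, with tokenizer.num_prompt_tokens == 5 (an illustrative value), prompt_ids is a
    # (1, 5) tensor: five mask-token ids for MLM-style models, or five random vocabulary ids for
    # the causal models handled above.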
    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator)

    optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        int(args.iters * len(train_loader)),
    )
    tot_steps = len(train_loader)

    # Log the label words suggested by the (randomly initialized) projection before training.
    projection.to(word_embeddings.device)
    scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
    scores = F.softmax(scores, dim=0)
    for i, row in enumerate(scores):
        _, top = row.topk(args.k)
        decoded = tokenizer.convert_ids_to_tokens(top.tolist())
        logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")

    best_results = {
        "best_acc": 0.0,
        "template": args.template,
        "model_name": args.model_name,
        "dataset_name": args.dataset_name,
        "task": args.task
    }
    logger.info('Training')
    for iters in range(args.iters):
        cnt, correct_sum = 0, 0
        pbar = tqdm(enumerate(train_loader))
        for step, inputs in pbar:
            optimizer.zero_grad()
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)

            # The backbone is frozen; only the linear projection receives gradients.
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()

            predict_mask = predict_mask.to(args.device)
            projection = projection.to(args.device)
            label = inputs["label"].to(args.device)
            if "opt" in args.model_name and False:  # disabled branch, kept for reference
                predict_embeddings = embeddings[:, 0].view(embeddings.size(0), -1).contiguous()
            else:
                predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, label)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum().detach().cpu()
            loss.backward()
            if "opt" in args.model_name:
                torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2)
            optimizer.step()
            scheduler.step()

            cnt += len(label)
            correct_sum += correct
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']
            del inputs
            pbar.set_description(
                f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] loss:{loss:0.4f} '
                f'acc:{correct / label.shape[0]:0.4f} lr:{current_lr:0.4f}'
            )
        train_accuracy = float(correct_sum / cnt)

        # Re-score the vocabulary with the updated projection and record this epoch's candidates.
        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        best_results["score"] = scores.detach().cpu().numpy()
        for i, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top.tolist())
            best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu()
            best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded)
            print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
        print()

        # Only start evaluating on the dev set after the first 20 iterations.
        if iters < 20:
            continue
        # Dev-set evaluation with the current projection.
        cnt, correct_sum = 0, 0
        pbar = tqdm(dev_loader)
        for inputs in pbar:
            label = inputs["label"].to(device)
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_mask = predict_mask.to(embeddings.device)
            projection = projection.to(embeddings.device)
            label = label.to(embeddings.device)
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum()
            cnt += len(label)
            correct_sum += correct
        accuracy = float(correct_sum / cnt)
        print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}")

        # Keep the label words from the best-performing epoch.
        if accuracy > best_results["best_acc"]:
            best_results["best_acc"] = accuracy
            for i, row in enumerate(scores):
                best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"]
                best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"]
        print()
        torch.save(best_results, args.output)
if __name__ == '__main__':
    args = augments.get_args()
    if args.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level)
    label_search(args)
    topk_search(args, largest=True)