import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq


class AbScores:
    """Scoring utilities (pseudo log-likelihood and confidence scores) for an
    AbLang-style masked language model."""

    def __init__(self, device='cpu', ncpu=1):
        # The scoring methods read the device from self.used_device.
        self.used_device = torch.device(device)
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        # Attach the language model and its tokenizer after construction.
        self.AbLang = model
        self.tokenizer = tokenizer

    def _encode_sequences(self, seqs):
        # Residue-level embeddings from the representation module (AbRep).
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang.AbRep(tokens).last_hidden_states.cpu().numpy()

    def _predict_logits(self, seqs):
        # Per-position vocabulary logits from a single unmasked forward pass.
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens), tokens

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Pseudo log-likelihood of each sequence: every residue is masked in
        turn and predicted from the remaining residues, and the resulting
        per-residue log-probabilities are averaged.
        """
        plls = []
        for seq in seqs:
            labels = self.tokenizer(
                seq, pad=True, w_extra_tkns=False, device=self.used_device
            )
            special_tokens = torch.as_tensor(
                self.tokenizer.all_special_tokens, device=self.used_device
            )
            # Positions holding actual residues; special tokens are not scored.
            idxs = (~torch.isin(labels, special_tokens)).nonzero()
            # One copy of the sequence per residue, each with that residue masked.
            masked_tokens = labels.repeat(len(idxs), 1)
            for num, idx in enumerate(idxs):
                masked_tokens[num, idx[1]] = self.tokenizer.mask_token
            with torch.no_grad():
                logits = self.AbLang(masked_tokens)
            # Remove special tokens from the prediction distribution.
            logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
            # For each copy, keep only the logits at its masked position.
            logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)])
            labels = labels[0, idxs[:, 1]]
            nll = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
            plls.append(-nll)
        return torch.stack(plls, dim=0).cpu().numpy()

    def confidence(self, seqs, **kwargs):
        """
        Mean log-likelihood of each sequence from a single unmasked forward
        pass. Cheaper than pseudo_log_likelihood, but each residue can attend
        to itself, so scores tend to be optimistic.
        """
        labels = self.tokenizer(
            seqs, pad=True, w_extra_tkns=False, device=self.used_device
        )
        with torch.no_grad():
            logits = self.AbLang(labels)
        # Remove special tokens from the prediction distribution.
        logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
        special_tokens = torch.as_tensor(
            self.tokenizer.all_special_tokens, device=self.used_device
        )
        plls = []
        for label, logit in zip(labels, logits):
            # Score only positions holding actual residues.
            idxs = (~torch.isin(label, special_tokens)).nonzero().squeeze(1)
            nll = torch.nn.functional.cross_entropy(
                logit[idxs], label[idxs], reduction="mean"
            )
            plls.append(-nll)
        return torch.stack(plls, dim=0).cpu().numpy()
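

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library). The
# _ToyTokenizer and _ToyModel classes below are invented stand-ins showing the
# interface AbScores assumes: a callable tokenizer exposing all_special_tokens
# (a list of token ids) and mask_token (an id), plus a model that maps a
# (batch, length) id tensor to (batch, length, vocab) logits. In practice the
# real AbLang model and tokenizer are attached via _initiate_abencoding.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyTokenizer:
        """Maps each character to an integer id; pads with id 0."""

        all_special_tokens = [0, 1, 2]  # e.g. pad / start / stop ids
        mask_token = 3

        def __call__(self, seqs, pad=True, w_extra_tkns=False, device="cpu"):
            if isinstance(seqs, str):
                seqs = [seqs]
            max_len = max(len(s) for s in seqs)
            ids = [
                [ord(c) % 20 + 4 for c in s] + [0] * (max_len - len(s))
                for s in seqs
            ]
            return torch.tensor(ids, dtype=torch.long, device=device)

    class _ToyModel(torch.nn.Module):
        """Random embedding + linear head standing in for the language model."""

        def __init__(self, vocab_size=24):
            super().__init__()
            self.emb = torch.nn.Embedding(vocab_size, 16)
            self.head = torch.nn.Linear(16, vocab_size)

        def forward(self, tokens):
            return self.head(self.emb(tokens))

    scorer = AbScores(device="cpu")
    scorer._initiate_abencoding(_ToyModel(), _ToyTokenizer())

    seqs = ["EVQLVESGG", "DIQMTQSPS"]
    print("pseudo_log_likelihood:", scorer.pseudo_log_likelihood(seqs))
    print("confidence:           ", scorer.confidence(seqs))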
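
    # Cross-check of the PLL definition: recompute the score of the first
    # sequence by hand, masking one position at a time. This should match
    # pseudo_log_likelihood(seqs)[0] up to floating-point error.
    labels = scorer.tokenizer(seqs[0], device="cpu")[0]
    logps = []
    for pos, true_id in enumerate(labels.tolist()):
        masked = labels.clone().unsqueeze(0)
        masked[0, pos] = scorer.tokenizer.mask_token
        with torch.no_grad():
            logits = scorer.AbLang(masked)[0, pos]
        logits[scorer.tokenizer.all_special_tokens] = -float("inf")
        logps.append(torch.log_softmax(logits, dim=-1)[true_id])
    print("manual PLL (seq 0):   ", torch.stack(logps).mean().item())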