import numpy as np
import torch
import torch.nn.functional as F

from datasets import load_dataset
from lm_eval.base import BaseLM


def set_seed(seed):
    # Seed the NumPy and PyTorch RNGs for reproducibility.
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
    if dataset_name == "wikitext2":
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        testdata = "".join(testdata['text']).split('\n')
    elif dataset_name == "c4":
        testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text']
    else:
        raise NotImplementedError

    # Drop empty documents, then tokenize each one and append an EOS token
    # to mark the document boundary.
    testdata = [item for item in testdata if item != ""]
    tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata]

    # Greedily pack the tokenized documents into sequences of at most
    # `seqlen` tokens, each starting with a BOS token.
    data, doc = [], [tokenizer.bos_token_id]
    for sen in tokenized_text:
        if len(sen) >= seqlen:  # too long to fit alongside the BOS token; skip
            continue
        if len(doc) + len(sen) > seqlen:
            data.append(doc)
            doc = [tokenizer.bos_token_id]
        doc.extend(sen)
    if 1 < len(doc) <= seqlen:
        data.append(doc)
    return data
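
# A minimal perplexity sketch over the packed sequences (illustrative only:
# it assumes `model` is an HF-style causal LM already on CUDA and `tokenizer`
# matches it; neither is defined in this module):
#
#   data = get_test_dataset("wikitext2", tokenizer, seqlen=2048)
#   nlls, n_tokens = [], 0
#   for seq in data:
#       ids = torch.tensor(seq, device="cuda").unsqueeze(0)
#       with torch.no_grad():
#           logits = model(ids).logits
#       # Shift so position t predicts token t + 1, then sum the NLL.
#       nlls.append(F.cross_entropy(logits[0, :-1], ids[0, 1:], reduction="sum"))
#       n_tokens += ids.numel() - 1
#   ppl = torch.exp(torch.stack(nlls).sum() / n_tokens)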


class LMEvalAdaptor(BaseLM):
    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size
        self._max_length = max_length  # -1 means "infer from the model config"

    @property
    def eot_token_id(self):
        # The harness's end-of-text token; we use the tokenizer's EOS token.
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length != -1:
            return self._max_length
        if hasattr(self.model.config, "n_ctx"):
            return self.model.config.n_ctx
        elif hasattr(self.model.config, "max_position_embeddings"):
            return self.model.config.max_position_embeddings
        elif hasattr(self.model.config, "n_positions"):
            return self.model.config.n_positions
        elif any(name in self.model_name for name in ("bloom", "llama", "mpt", "falcon")):
            return 2048
        else:
            print(self.model.config)
            raise NotImplementedError

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        # Hard-coded: the adaptor assumes the model has been moved to CUDA.
        return "cuda"

    def tok_encode(self, string: str, add_special_tokens=True):
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            context, continuation = context.strip(), continuation.strip()
            if context == "":
                # Empty prompt: use end-of-text as the sole context token.
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context, add_special_tokens=True)

            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)
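
    # Note: the harness passes loglikelihood (context, continuation) string
    # pairs, e.g. ("Question: Is water wet?\nAnswer:", " yes"), and the method
    # scores the continuation given the context (a hypothetical example).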

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence];
        the size of sequence may vary from call to call.

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model.
        """
        with torch.no_grad():
            out = self.model(inps)[0]
        return out

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
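

# Example wiring into the evaluation harness (a sketch, not part of this
# module; assumes an lm-eval v0.3-style `evaluator.simple_evaluate`, which
# accepts an LM instance, and a placeholder `model_path`):
#
#   from lm_eval import evaluator
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained(model_path).cuda()
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   lm = LMEvalAdaptor(model_path, model, tokenizer, batch_size=1)
#   results = evaluator.simple_evaluate(model=lm, tasks=["hellaswag"], num_fewshot=0)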