Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import PreTrainedTokenizer | |
class TokenizedForMCRightPad(Dataset):
    """Multiple-choice dataset that right-pads each (query + choice) text.

    Every example in ``data`` is a dict ``{"query": str, "choices": list[str]}``.
    ``prompt_fn(query, choice)`` must return a ``(query_only_text, full_text)``
    pair where ``full_text`` is the query with the choice appended.  All full
    texts are tokenized and right-padded to the single longest tokenized
    length observed across the whole dataset, so every choice of every query
    shares one fixed sequence width.
    """

    def __init__(self, data, tok: "PreTrainedTokenizer", prompt_fn):
        """
        Args:
            data: list of ``{"query": str, "choices": list[str]}`` examples.
            tok: a Hugging Face tokenizer (anything with ``encode`` and
                ``__call__`` returning ``input_ids`` / ``attention_mask``).
            prompt_fn: maps ``(query, choice)`` to ``(query_only, full_text)``.
        """
        self.tok = tok
        self.prompt_fn = prompt_fn
        # One global width so every encoded choice shares the same padded length.
        self.max_length = self._find_max_length(data)
        self.data = self._build_mc_data(data)

    def _find_max_length(self, data):
        """Return the longest tokenized length of any full (query + choice) text."""

        def tok_len(t):
            return len(self.tok.encode(t))

        max_len = 0
        for ex in data:
            query = ex["query"]
            # prompt_fn returns (query_only, full_text); pad against the full text.
            # Folding per choice also tolerates an empty choices list, which the
            # original `max(max_len, *len_choices)` would crash on.
            for choice in ex["choices"]:
                max_len = max(max_len, tok_len(self.prompt_fn(query, choice)[1]))
        return max_len

    def _build_mc_data(self, data):
        """Tokenize every (query, choice) pair.

        Raises:
            ValueError: if queries do not all have the same number of choices.
        """
        num_choices = {len(e["choices"]) for e in data}
        if len(num_choices) != 1:
            raise ValueError(f"Queries have different number of choices, which is not supported! #choices: {num_choices}")
        processed = []
        for ex in data:
            query, choices = ex["query"], ex["choices"]
            pairs = [self.prompt_fn(query, choice) for choice in choices]
            processed.append([self.tokenize(t_query, t_full) for t_query, t_full in pairs])
        return processed

    def tokenize_demonstration(self, demonstration):
        """Tokenize a demonstration string with no padding applied."""
        e = self.tok(demonstration)
        return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding

    def tokenize(self, only_query, full_text):
        """Encode ``full_text`` right-padded to ``self.max_length`` and locate
        the token span occupied by the choice continuation.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (padded), plus the
            half-open token span ``[choice_start, choice_end)`` of the choice
            inside the padded sequence.
        """
        # add_special_tokens=False keeps query/full token counts directly
        # comparable, so the choice span is simply [len(query), len(full)).
        # <pad> is not a special token, so padding is unaffected by this flag.
        tok_only_query = self.tok(only_query, add_special_tokens=False)
        tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
        tok_full = self.tok(
            full_text,
            padding="max_length",
            max_length=self.max_length,
            add_special_tokens=False,
        )
        len_query = len(tok_only_query.input_ids)
        len_full = len(tok_full_no_padding.input_ids)
        return {
            "input_ids": tok_full.input_ids,
            "attention_mask": tok_full.attention_mask,
            "choice_start": len_query,
            "choice_end": len_full,
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return, for one query, a list of per-choice tuples:
        ``(input_ids, attention_mask, choice_start, choice_end)``.
        """

        def _get_one_item(e):
            return (
                torch.LongTensor(e["input_ids"]),
                torch.LongTensor(e["attention_mask"]),
                e["choice_start"],
                e["choice_end"],
            )

        return [_get_one_item(e) for e in self.data[idx]]