'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-19 15:39:48
Description: all classes for loading data.
'''
import os
import torch
import json

from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader

from common.utils import InputData

ABS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../")


class DataFactory(object):
    def __init__(self, tokenizer, use_multi_intent=False, to_lower_case=True):
        """Manage dataset loading, label vocabularies and batch collation.

        Args:
            tokenizer (Tokenizer): tokenizer used to encode input text.
            use_multi_intent (bool, optional): whether an utterance may carry
                multiple '#'-separated intents. Defaults to False.
            to_lower_case (bool, optional): whether to lower-case input text
                before tokenization. Defaults to True.
        """
        self.tokenizer = tokenizer
        self.slot_label_list = []
        self.intent_label_list = []
        self.use_multi = use_multi_intent
        self.to_lower_case = to_lower_case
        self.slot_label_dict = None
        self.intent_label_dict = None

    def __is_supported_datasets(self, dataset_name: str) -> bool:
        return dataset_name.lower() in ["atis", "snips", "mix-atis", "mix-snips"]
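
    # Load a dataset split either from the LightChen2333/OpenSLU hub datasets
    # (when the config names a supported dataset) or from a local JSON-lines
    # file with "text", "slot" and "intent" fields.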
    def load_dataset(self, dataset_config, split="train"):
        dataset_name = None
        if split not in dataset_config:
            dataset_name = dataset_config.get("dataset_name")
        elif self.__is_supported_datasets(dataset_config[split]):
            dataset_name = dataset_config[split].lower()
        if dataset_name is not None:
            return load_dataset("LightChen2333/OpenSLU", dataset_name, split=split)
        else:
            data_file = dataset_config[split]
            data_dict = {"text": [], "slot": [], "intent": []}
            with open(data_file, encoding="utf-8") as f:
                for line in f:
                    row = json.loads(line)
                    data_dict["text"].append(row["text"])
                    data_dict["slot"].append(row["slot"])
                    data_dict["intent"].append(row["intent"])
            return Dataset.from_dict(data_dict)
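
    # Build label-to-index mappings by scanning every intent and slot tag that
    # appears in the given dataset.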
    def update_label_names(self, dataset):
        for intent_labels in dataset["intent"]:
            if self.use_multi:
                intent_label = intent_labels.split("#")
            else:
                intent_label = [intent_labels]
            for x in intent_label:
                if x not in self.intent_label_list:
                    self.intent_label_list.append(x)
        for slot_label in dataset["slot"]:
            for x in slot_label:
                if x not in self.slot_label_list:
                    self.slot_label_list.append(x)
        self.intent_label_dict = {key: index for index, key in enumerate(self.intent_label_list)}
        self.slot_label_dict = {key: index for index, key in enumerate(self.slot_label_list)}
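
    # Extend the vocabulary of the non-pretrained word-level tokenizer with
    # the utterances observed in the dataset.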
    def update_vocabulary(self, dataset):
        if self.tokenizer.name_or_path in ["word_tokenizer"]:
            for data in dataset:
                self.tokenizer.add_instance(data["text"])
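
    # For each sample of a fast-tokenized batch, yield the sub-token position
    # of every word: [first] when the word maps to a single sub-token, or
    # [first, last] when it is split into several sub-tokens.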
    @staticmethod
    def fast_align_data(text, padding_side="right"):
        for i in range(len(text.input_ids)):
            desired_output = []
            for word_id in text.word_ids(i):
                if word_id is not None:
                    start, end = text.word_to_tokens(
                        i, word_id, sequence_index=0 if padding_side == "right" else 1)
                    if start == end - 1:
                        tokens = [start]
                    else:
                        tokens = [start, end - 1]
                    if len(desired_output) == 0 or desired_output[-1] != tokens:
                        desired_output.append(tokens)
            yield desired_output
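
    # Collate a batch with a fast tokenizer: encode the pre-split words,
    # project word-level slot labels onto sub-token positions (remaining
    # sub-tokens inherit the label of the previous position), and convert
    # intent labels to index/one-hot tensors or keep them as raw strings
    # depending on label2tensor.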
    def fast_align(self,
                   batch,
                   ignore_index=-100,
                   device="cuda",
                   config=None,
                   enable_label=True,
                   label2tensor=True):
        if self.to_lower_case:
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
        text = self.tokenizer(input_list,
                              return_tensors="pt",
                              padding=True,
                              is_split_into_words=True,
                              truncation=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        DataFactory.fast_align_data(text, padding_side=self.tokenizer.padding_side)):
                    num = 0
                    assert len(offsets) == len(batch[i]["text"])
                    assert len(offsets) == len(batch[i]["slot"])
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                slot = slot_mask.clone()
                attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                for i, slot_batch in enumerate(slot):
                    for j, x in enumerate(slot_batch):
                        if x == ignore_index and text.attention_mask[i][j] == attention_id and (
                                text.input_ids[i][j] not in self.tokenizer.all_special_ids
                                or text.input_ids[i][j] == self.tokenizer.unk_token_id):
                            slot[i][j] = slot[i][j - 1]
                slot = slot.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(DataFactory.fast_align_data(text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [
                        [x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))
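
    # Align whitespace-split words with the character offsets returned by the
    # tokenizer (return_offsets_mapping=True), yielding per sample the token
    # index [first] or token range [first, last] that covers each word.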
    def general_align_data(self, split_text_list, raw_text_list, encoded_text):
        for i in range(len(split_text_list)):
            desired_output = []
            jdx = 0
            offset = encoded_text.offset_mapping[i].tolist()
            split_texts = split_text_list[i]
            raw_text = raw_text_list[i]
            last = 0
            temp_offset = []
            for off in offset:
                s, e = off
                if len(temp_offset) > 0 and (e != 0 and last == s):
                    len_1 = off[1] - off[0]
                    len_2 = temp_offset[-1][1] - temp_offset[-1][0]
                    if len_1 > len_2:
                        temp_offset.pop(-1)
                        temp_offset.append([0, 0])
                        temp_offset.append(off)
                        continue
                temp_offset.append(off)
                last = s
            offset = temp_offset
            for split_text in split_texts:
                while jdx < len(offset) and offset[jdx][0] == 0 and offset[jdx][1] == 0:
                    jdx += 1
                if jdx == len(offset):
                    continue
                start_, end_ = offset[jdx]
                tokens = None
                if split_text == raw_text[start_:end_].strip():
                    tokens = [jdx]
                else:
                    # Handle words split into several sub-tokens, e.g. "xxx" -> "xx" "#x"
                    temp_jdx = jdx
                    last_str = raw_text[start_:end_].strip()
                    while last_str != split_text and temp_jdx < len(offset) - 1:
                        temp_jdx += 1
                        last_str += raw_text[offset[temp_jdx][0]:offset[temp_jdx][1]].strip()
                    if temp_jdx == jdx:
                        raise ValueError("Illegal Input data")
                    elif last_str == split_text:
                        tokens = [jdx, temp_jdx]
                        jdx = temp_jdx
                    else:
                        jdx -= 1
                jdx += 1
                if tokens is not None:
                    desired_output.append(tokens)
            yield desired_output
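
    # Collate a batch with an offset-mapping tokenizer: encode the joined raw
    # text and use general_align_data to project word-level slot labels onto
    # token positions; intents are handled the same way as in fast_align.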
    def general_align(self,
                      batch,
                      ignore_index=-100,
                      device="cuda",
                      config=None,
                      enable_label=True,
                      label2tensor=True,
                      locale="en-US"):
        if self.to_lower_case:
            raw_data = [" ".join(x["text"]).lower() if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
            raw_data = [" ".join(x["text"]) if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
        text = self.tokenizer(raw_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              return_offsets_mapping=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    # if len(offsets) != len(batch[i]["text"]) or len(offsets) != len(batch[i]["slot"]):
                    #     if
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                # slot = slot_mask.clone()
                # attentin_id = 0 if self.tokenizer.padding_side == "right" else 1
                # for i, slot_batch in enumerate(slot):
                #     for j, x in enumerate(slot_batch):
                #         if x == ignore_index and text.attention_mask[i][j] == attentin_id and text.input_ids[i][
                #             j] not in self.tokenizer.all_special_ids:
                #             slot[i][j] = slot[i][j - 1]
                slot = slot_mask.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [
                        [x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))
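
    # Dispatch batch collation to fast_align (word_ids-based alignment) or
    # general_align (offset-mapping-based alignment) depending on align_mode.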
    def batch_fn(self,
                 batch,
                 ignore_index=-100,
                 device="cuda",
                 config=None,
                 align_mode="fast",
                 enable_label=True,
                 label2tensor=True):
        if align_mode == "fast":
            # try:
            return self.fast_align(batch,
                                   ignore_index=ignore_index,
                                   device=device,
                                   config=config,
                                   enable_label=enable_label,
                                   label2tensor=label2tensor)
            # except:
            #     return self.general_align(batch,
            #                               ignore_index=ignore_index,
            #                               device=device,
            #                               config=config,
            #                               enable_label=enable_label,
            #                               label2tensor=label2tensor)
        else:
            return self.general_align(batch,
                                      ignore_index=ignore_index,
                                      device=device,
                                      config=config,
                                      enable_label=enable_label,
                                      label2tensor=label2tensor)
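
    # Wrap a dataset in a torch DataLoader whose collate_fn runs batch_fn; any
    # extra keyword arguments are forwarded to the tokenizer call.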
    def get_data_loader(self,
                        dataset,
                        batch_size,
                        shuffle=False,
                        device="cuda",
                        enable_label=True,
                        align_mode="fast",
                        label2tensor=True,
                        **config):
        data_loader = DataLoader(dataset,
                                 shuffle=shuffle,
                                 batch_size=batch_size,
                                 collate_fn=lambda x: self.batch_fn(x,
                                                                    device=device,
                                                                    config=config,
                                                                    enable_label=enable_label,
                                                                    align_mode=align_mode,
                                                                    label2tensor=label2tensor))
        return data_loader
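

# Usage sketch: a rough illustration of how DataFactory is typically wired
# together, assuming a HuggingFace fast tokenizer and the bundled "atis"
# dataset; exact config keys and training integration may differ in the
# full OpenSLU pipeline.
#
# from transformers import AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# factory = DataFactory(tokenizer, use_multi_intent=False)
# train_set = factory.load_dataset({"train": "atis"}, split="train")
# factory.update_label_names(train_set)
# factory.update_vocabulary(train_set)
# loader = factory.get_data_loader(train_set, batch_size=16, shuffle=True, device="cpu")
# for batch in loader:
#     ...  # batch is an InputData of (tokenized text, slot labels, intent labels)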