|
|
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-19 15:39:48
Description: All classes for loading data.
'''
|
|
import os
import json

import torch
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader

from common.utils import InputData

ABS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../")


class DataFactory(object):
    def __init__(self, tokenizer, use_multi_intent=False, to_lower_case=True):
        """Create a DataFactory that loads datasets, builds label maps and batches data.

        Args:
            tokenizer (Tokenizer): tokenizer used to encode the input text.
            use_multi_intent (bool, optional): whether an utterance can carry
                multiple '#'-separated intents. Defaults to False.
            to_lower_case (bool, optional): whether to lower-case input tokens
                before tokenization. Defaults to True.
        """
        self.tokenizer = tokenizer
        self.slot_label_list = []
        self.intent_label_list = []
        self.use_multi = use_multi_intent
        self.to_lower_case = to_lower_case
        self.slot_label_dict = None
        self.intent_label_dict = None
    def __is_supported_datasets(self, dataset_name: str) -> bool:
        return dataset_name.lower() in ["atis", "snips", "mix-atis", "mix-snips"]
    def load_dataset(self, dataset_config, split="train"):
        dataset_name = None
        if split not in dataset_config:
            dataset_name = dataset_config.get("dataset_name")
        elif self.__is_supported_datasets(dataset_config[split]):
            dataset_name = dataset_config[split].lower()
        if dataset_name is not None:
            # Supported datasets are loaded from the Hugging Face Hub.
            return load_dataset("LightChen2333/OpenSLU", dataset_name, split=split)
        else:
            # Otherwise treat the config entry as a path to a local
            # JSON-lines file with one utterance per line.
            data_file = dataset_config[split]
            data_dict = {"text": [], "slot": [], "intent": []}
            with open(data_file, encoding="utf-8") as f:
                for line in f:
                    row = json.loads(line)
                    data_dict["text"].append(row["text"])
                    data_dict["slot"].append(row["slot"])
                    data_dict["intent"].append(row["intent"])
            return Dataset.from_dict(data_dict)
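    # Illustrative line of such a local file (an assumption inferred from the
    # fields read above, not from project documentation): pre-tokenized text
    # with one slot tag per token, e.g.
    #   {"text": ["show", "flights", "to", "denver"],
    #    "slot": ["O", "O", "O", "B-toloc.city_name"],
    #    "intent": "atis_flight"}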
|
|
|
|
|
    def update_label_names(self, dataset):
        for intent_labels in dataset["intent"]:
            if self.use_multi:
                intent_label = intent_labels.split("#")
            else:
                intent_label = [intent_labels]
            for x in intent_label:
                if x not in self.intent_label_list:
                    self.intent_label_list.append(x)
        for slot_label in dataset["slot"]:
            for x in slot_label:
                if x not in self.slot_label_list:
                    self.slot_label_list.append(x)
        self.intent_label_dict = {key: index for index, key in enumerate(self.intent_label_list)}
        self.slot_label_dict = {key: index for index, key in enumerate(self.slot_label_list)}
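    # For example (hypothetical ATIS-style labels), after one pass the maps
    # might look like:
    #   intent_label_dict -> {"atis_flight": 0, "atis_airfare": 1, ...}
    #   slot_label_dict   -> {"O": 0, "B-toloc.city_name": 1, ...}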
|
|
|
|
|
    def update_vocabulary(self, dataset):
        # Only the plain word-level tokenizer builds its vocabulary from the
        # data; pretrained subword tokenizers ship with a fixed vocabulary.
        if self.tokenizer.name_or_path in ["word_tokenizer"]:
            for data in dataset:
                self.tokenizer.add_instance(data["text"])
    @staticmethod
    def fast_align_data(text, padding_side="right"):
        # For each sequence in the batch, map every input word to the subword
        # tokens it was split into, yielding a list of [start] or
        # [start, last] token-index entries, one per word.
        for i in range(len(text.input_ids)):
            desired_output = []
            for word_id in text.word_ids(i):
                if word_id is not None:
                    start, end = text.word_to_tokens(
                        i, word_id, sequence_index=0 if padding_side == "right" else 1)
                    if start == end - 1:
                        tokens = [start]
                    else:
                        tokens = [start, end - 1]
                    if len(desired_output) == 0 or desired_output[-1] != tokens:
                        desired_output.append(tokens)
            yield desired_output
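    # Example (hypothetical tokenization): if "playing" becomes ["play",
    # "##ing"] at token positions 3 and 4, word_to_tokens returns (3, 5), so
    # [3, 4] is yielded for that word and only position 3 later receives the
    # slot label in fast_align.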
|
|
|
|
|
    def fast_align(self,
                   batch,
                   ignore_index=-100,
                   device="cuda",
                   config=None,
                   enable_label=True,
                   label2tensor=True):
        if config is None:  # avoid unpacking None when called without a config
            config = {}
        if self.to_lower_case:
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
        text = self.tokenizer(input_list,
                              return_tensors="pt",
                              padding=True,
                              is_split_into_words=True,
                              truncation=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                # Initialize every token position to ignore_index, then write
                # each word's slot label id onto its first subword token.
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        DataFactory.fast_align_data(text, padding_side=self.tokenizer.padding_side)):
                    num = 0
                    assert len(offsets) == len(batch[i]["text"])
                    assert len(offsets) == len(batch[i]["slot"])
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                slot = slot_mask.clone()
                # Fill positions still at ignore_index from the preceding
                # position's label, subject to the attention-mask and
                # special-token checks below ([UNK] counts as a normal token).
                attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                for i, slot_batch in enumerate(slot):
                    for j, x in enumerate(slot_batch):
                        if x == ignore_index and text.attention_mask[i][j] == attention_id and (
                                text.input_ids[i][j] not in self.tokenizer.all_special_ids
                                or text.input_ids[i][j] == self.tokenizer.unk_token_id):
                            slot[i][j] = slot[i][j - 1]
                slot = slot.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    # Multiple '#'-separated intents become a multi-hot vector.
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                # Keep labels as strings; '#' marks unlabeled token positions.
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(DataFactory.fast_align_data(text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [b["intent"].split("#") for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))
    def general_align_data(self, split_text_list, raw_text_list, encoded_text):
        # Align whitespace-split words with subword tokens via the tokenizer's
        # character offset mapping, yielding one list of [start] or
        # [start, end] token-index entries per sequence.
        for i in range(len(split_text_list)):
            desired_output = []
            jdx = 0
            offset = encoded_text.offset_mapping[i].tolist()
            split_texts = split_text_list[i]
            raw_text = raw_text_list[i]
            last = 0
            temp_offset = []
            # When two consecutive spans start at the same character, keep the
            # longer one and replace the shorter with an empty (0, 0) span.
            for off in offset:
                s, e = off
                if len(temp_offset) > 0 and (e != 0 and last == s):
                    len_1 = off[1] - off[0]
                    len_2 = temp_offset[-1][1] - temp_offset[-1][0]
                    if len_1 > len_2:
                        temp_offset.pop(-1)
                        temp_offset.append([0, 0])
                        temp_offset.append(off)
                        continue
                temp_offset.append(off)
                last = s
            offset = temp_offset
            for split_text in split_texts:
                # Special tokens map to the empty (0, 0) span; skip them.
                while jdx < len(offset) and offset[jdx][0] == 0 and offset[jdx][1] == 0:
                    jdx += 1
                if jdx == len(offset):
                    continue
                start_, end_ = offset[jdx]
                tokens = None
                if split_text == raw_text[start_:end_].strip():
                    tokens = [jdx]
                else:
                    # The word spans several tokens: concatenate token
                    # substrings until they reproduce the word.
                    temp_jdx = jdx
                    last_str = raw_text[start_:end_].strip()
                    while last_str != split_text and temp_jdx < len(offset) - 1:
                        temp_jdx += 1
                        last_str += raw_text[offset[temp_jdx][0]:offset[temp_jdx][1]].strip()
                    if temp_jdx == jdx:
                        raise ValueError("Illegal Input data")
                    elif last_str == split_text:
                        tokens = [jdx, temp_jdx]
                        jdx = temp_jdx
                    else:
                        jdx -= 1
                jdx += 1
                if tokens is not None:
                    desired_output.append(tokens)
            yield desired_output
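    # Example (hypothetical offsets): for raw_text "new york" with offset
    # spans [(0, 0), (0, 3), (4, 6), (6, 8), (0, 0)] (special tokens map to
    # (0, 0)), "new" matches span 1 directly -> [1], while "york" is rebuilt
    # from "yo" + "rk" -> [2, 3], so desired_output is [[1], [2, 3]].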
|
|
|
|
|
    def general_align(self,
                      batch,
                      ignore_index=-100,
                      device="cuda",
                      config=None,
                      enable_label=True,
                      label2tensor=True,
                      locale="en-US"):
        if config is None:  # avoid unpacking None when called without a config
            config = {}
        # Join words with spaces, except for locales written without them.
        if self.to_lower_case:
            raw_data = [" ".join(x["text"]).lower() if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
            raw_data = [" ".join(x["text"]) if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
        text = self.tokenizer(raw_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              return_offsets_mapping=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                # Write each word's slot label id onto its first subword token;
                # all other positions stay at ignore_index.
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                slot = slot_mask.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [b["intent"].split("#") for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))
    def batch_fn(self,
                 batch,
                 ignore_index=-100,
                 device="cuda",
                 config=None,
                 align_mode="fast",
                 enable_label=True,
                 label2tensor=True):
        if align_mode == "fast":
            # fast_align relies on word_ids() and therefore requires a fast
            # (Rust-backed) tokenizer fed pre-split words.
            return self.fast_align(batch,
                                   ignore_index=ignore_index,
                                   device=device,
                                   config=config,
                                   enable_label=enable_label,
                                   label2tensor=label2tensor)
        else:
            # general_align falls back to character offset mappings computed
            # over the joined raw text.
            return self.general_align(batch,
                                      ignore_index=ignore_index,
                                      device=device,
                                      config=config,
                                      enable_label=enable_label,
                                      label2tensor=label2tensor)
    def get_data_loader(self,
                        dataset,
                        batch_size,
                        shuffle=False,
                        device="cuda",
                        enable_label=True,
                        align_mode="fast",
                        label2tensor=True, **config):
        data_loader = DataLoader(dataset,
                                 shuffle=shuffle,
                                 batch_size=batch_size,
                                 collate_fn=lambda x: self.batch_fn(x,
                                                                    device=device,
                                                                    config=config,
                                                                    enable_label=enable_label,
                                                                    align_mode=align_mode,
                                                                    label2tensor=label2tensor))
        return data_loader
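    # Illustrative usage sketch (assumptions: a Hugging Face fast tokenizer
    # and the "atis" config hosted at LightChen2333/OpenSLU; the device
    # defaults to "cuda"):
    #
    #     from transformers import AutoTokenizer
    #     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     factory = DataFactory(tokenizer)
    #     train_set = factory.load_dataset({"train": "atis"}, split="train")
    #     factory.update_label_names(train_set)
    #     loader = factory.get_data_loader(train_set, batch_size=16, shuffle=True)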
|
|
|