| | |
| |
|
| | """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" |
| |
|
| | import torch |
| | import argparse |
| | from nltk import word_tokenize |
| | from tqdm import tqdm |
| | import numpy as np |
| | import json |
| |
|
def get_args():
    """Parse the command-line arguments of the preprocessing script.

    Returns:
        argparse.Namespace: the parsed arguments; ``--func`` selects which
        preprocessing step the ``__main__`` section runs.
    """
    parser = argparse.ArgumentParser(description="Preprocessing")

    parser.add_argument("--func", type=str, default=None,
                        help="choose to run which function")
    parser.add_argument("--raw_file", type=str, default=None,
                        help="path of the input file")
    parser.add_argument("--processed_file", type=str, default=None,
                        help="path of the output file")
    parser.add_argument("--knwl_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--resp_ref_file", type=str, default=None,
                        help="path of the response reference file")
    parser.add_argument("--knwl_gen_file", type=str, default=None,
                        help="path of the generated knowledge file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="path of the test file")
    parser.add_argument("--train_file", type=str, default=None,
                        help="path of the train file")
    parser.add_argument("--model_file", type=str, default=None,
                        help="path of the model file")
    # Implicit literal concatenation: a backslash continuation inside the
    # string would embed the source indentation into the --help output.
    parser.add_argument("--data_type", type=str, default=None,
                        help="data types, choose one out of three types: "
                             "wow_seen, wow_unseen, and woi")
    parser.add_argument("--seed", type=int, default=1234,
                        help="random seed")

    args = parser.parse_args()
    return args
| |
|
| |
|
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """Convert the Wizard of Wikipedia (wow) json file into a tsv file.

    ``raw_file`` holds a json list of dialog samples. Every wizard turn after
    the first turn of a dialog becomes one line of ``processed_file``::

        topic \t dialogue context \t golden knowledge \t golden response

    where the dialogue context is the preceding turns joined by " [SEP] ".

    Args:
        raw_file: path of the wow json file.
        processed_file: path of the tsv file to write.
        knwl_ref_file: optional path; if given, the golden knowledge of each
            written sample is written there, one per line.
        resp_ref_file: optional path; if given, the word-tokenized golden
            response of each written sample is written there, one per line.
    """
    from contextlib import ExitStack

    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    # ExitStack guarantees every opened output file is closed, even if
    # processing raises part-way through.
    with ExitStack() as stack:
        fproc = stack.enter_context(open(processed_file, "w"))
        fknwl = (stack.enter_context(open(knwl_ref_file, "w"))
                 if knwl_ref_file else None)
        fresp = (stack.enter_context(open(resp_ref_file, "w"))
                 if resp_ref_file else None)

        for sample in tqdm(dialog_data):
            dialog = sample["dialog"]
            turn_list = []  # dialog history seen so far

            for j, turn in enumerate(dialog):
                text = turn["text"]
                # make sure every turn ends with a punctuation mark
                if not text.endswith(("?", ".", "!")):
                    text = text + "."

                if j == 0:
                    # the first turn only serves as dialog context
                    turn_list.append(text)
                    continue

                speaker = turn["speaker"].lower()
                if "wizard" not in speaker:
                    assert "apprentice" in speaker
                    turn_list.append(text)
                    continue

                checked_sentence = list(turn["checked_sentence"].values())
                checked_passage = list(turn["checked_passage"].values())
                assert len(checked_sentence) <= 1

                # golden knowledge: the checked sentence, if any
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # topic: the checked passage when there is one, otherwise
                # the topic chosen for the whole dialog
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                dialog_context = " [SEP] ".join(turn_list)
                knowledge = checked_sentence
                response = text
                # the untokenized response is what later contexts contain
                turn_list.append(response)

                fproc.write(topic + "\t" + dialog_context + "\t" +
                            knowledge + "\t" + response + "\n")
                if fknwl is not None:
                    fknwl.write(knowledge + "\n")
                if fresp is not None:
                    # tokenize for evaluation
                    fresp.write(" ".join(word_tokenize(response)) + "\n")
| |
|
| |
|
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """Convert the Wizard of Internet (woi) jsonl file into a tsv file.

    Each non-empty line of ``raw_file`` is a json object with a single entry
    whose value holds the ``dialog_history`` of one dialog. Every
    "Wizard => Apprentice" turn that is not the first turn of the history
    and has a selected knowledge sentence becomes one line of
    ``processed_file``::

        topic \t dialogue context \t golden knowledge \t golden response

    The topic is the text of the most recent "Wizard => SearchAgent" turn,
    and the dialogue context is the preceding turns joined by " [SEP] ".
    Newlines, carriage returns and tabs are removed from all four fields.

    Args:
        raw_file: path of the woi jsonl file.
        processed_file: path of the tsv file to write.
        knwl_ref_file: optional path; if given, the golden knowledge of each
            written sample is written there, one per line.
        resp_ref_file: optional path; if given, the word-tokenized golden
            response of each written sample is written there, one per line.
    """
    from contextlib import ExitStack

    # deletion table for the characters that would break the tsv layout
    strip_table = str.maketrans("", "", "\n\r\t")

    print("> Processing %s" % raw_file)
    # ExitStack closes every file, even if processing raises part-way.
    with ExitStack() as stack:
        fproc = stack.enter_context(open(processed_file, "w"))
        fknwl = (stack.enter_context(open(knwl_ref_file, "w"))
                 if knwl_ref_file else None)
        fresp = (stack.enter_context(open(resp_ref_file, "w"))
                 if resp_ref_file else None)
        fr = stack.enter_context(open(raw_file, "r"))

        for line in tqdm(fr):
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing empty line)
                continue
            item_dict = json.loads(line)
            # the object maps one data id to the whole dialog content
            item_dict = list(item_dict.values())[0]

            dialog_data = item_dict['dialog_history']

            turn_list = []  # dialog history seen so far
            search_text = ""
            for item in dialog_data:
                action = item['action']

                if action == "Wizard => SearchAgent":
                    search_text = item['text']

                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # first turn: only serves as dialog context
                        turn_list.append(item['text'])
                        continue

                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    # the first entry of selected_contents holds a flag; when
                    # truthy the turn is treated as using no knowledge
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    if flag:
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"
                    else:
                        topic = search_text
                        knwl_sent = ""
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            # NOTE(review): `break` leaves only the inner loop,
                            # so a selection in a later content group overrides
                            # an earlier one — confirm this is intended.
                            for c, s in zip(content, select):
                                if s:
                                    knwl_sent = c
                                    break

                        if knwl_sent == "":
                            # no knowledge is used for the response
                            topic = "no_topic"
                            knwl_sent = "no_passages_used"

                    topic = topic.translate(strip_table)
                    dialog_context = " [SEP] ".join(turn_list).translate(strip_table)
                    knwl_sent = knwl_sent.translate(strip_table)
                    response = item['text'].translate(strip_table)

                    if topic != "no_topic":
                        fproc.write(topic + "\t" + dialog_context + "\t" +
                                    knwl_sent + "\t" + response + "\n")
                        if fknwl is not None:
                            fknwl.write(knwl_sent + "\n")
                        if fresp is not None:
                            # Tokenize a copy for evaluation only: the history
                            # below must not depend on whether a response
                            # reference file was requested.
                            tokenized_response = " ".join(word_tokenize(response))
                            fresp.write(tokenized_response + "\n")

                    turn_list.append(response)

                elif action == "Apprentice => Wizard":
                    turn_list.append(item['text'])

                else:
                    assert action == "SearchAgent => Wizard", \
                        "Please check whether you have used the correct data!"
| |
|
| |
|
def get_database(test_datapath, train_datapath, data_type):
    """Collect knowledge-generation prompt candidates from the training data.

    Args:
        test_datapath: tsv file whose first column gives the test topics.
        train_datapath: tsv file of ``topic \t context \t knowledge \t
            response`` lines.
        data_type: one of "wow_seen", "wow_unseen" or "woi".

    Returns:
        ``(train_data_by_topic, dialog_data_by_topic, dialog_examples)``:
        the first two map every test topic to parallel lists of prompt
        instances and dialog strings; the last is a list of
        ``(topic, dialog_example, instance)`` for training samples of
        non-test topics that pass the length and pronoun filters.
    """
    assert data_type in ["wow_seen", "wow_unseen", "woi"], \
        "Please input a correct data type!!"

    print("> reading test data from %s" % test_datapath)
    with open(test_datapath, "r") as test_file:
        test_topics = {row.strip().split("\t")[0]: True for row in test_file}

    # only the "wow_seen" setting skips the topic-based filters and prefix
    is_seen = data_type == "wow_seen"

    print("> reading data from %s" % train_datapath)
    train_data_by_topic = {}
    dialog_data_by_topic = {}
    dialog_examples = []
    with open(train_datapath, "r") as train_file:
        for row in train_file:
            fields = row.strip().split("\t")
            topic = fields[0]
            recent_turns = fields[1].split(" [SEP] ")[-3:]
            knowledge = fields[2]
            _response = fields[3]  # unused, but the 4th column must exist

            if knowledge == "no_passages_used":
                continue
            if not is_seen and ("(" in knowledge or ")" in knowledge
                                or topic not in knowledge):
                continue

            instance = "( " + recent_turns[-1] + " ) " + topic + " => " + knowledge
            prefix = "" if is_seen else "( " + topic + " ) "
            dialog_example = prefix + " ".join(recent_turns)

            if topic in test_topics:
                train_data_by_topic.setdefault(topic, []).append(instance)
                dialog_data_by_topic.setdefault(topic, []).append(dialog_example)
                continue

            # other topics: keep only short knowledge that does not open
            # with a pronoun.
            # NOTE(review): the prefix test also drops sentences starting
            # with e.g. "Italy" or "Thistle" — confirm this is intended.
            if len(knowledge.split()) > 20:
                continue
            if knowledge.startswith(("It", "it", "This", "this")):
                continue
            dialog_examples.append((topic, dialog_example, instance))

    return train_data_by_topic, dialog_data_by_topic, dialog_examples
| |
|
| |
|
# Module-level cache: topic -> embeddings of that topic's dialog examples,
# stored on the CPU and filled by select_prompts_based_on_similarity.
emb_dict = {}


def select_prompts_based_on_similarity(
        query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
    """Select the prompts whose dialogs are most similar to the query.

    Similarity is the dot product between the encoder's ``pooler_output``
    for the query and for each entry of ``dialog_list``. The dialog
    embeddings are computed once per ``topic`` and cached in ``emb_dict``.

    Args:
        query: dialog string to match against.
        dialog_list: candidate dialog strings; ``dialog_list[i]`` pairs with
            ``prompt_list[i]``.
        prompt_list: prompt instances to select from.
        topic: cache key for the embeddings of ``dialog_list``.
        tokenizer: object whose ``encode`` returns a list of token ids.
        encoder: model called as ``encoder(input_ids=...)`` that returns an
            object with a ``pooler_output`` attribute (assumed shape
            ``(1, hidden)`` per call — TODO confirm).
        topk: maximum number of prompts to return; clamped to the number of
            candidates.

    Returns:
        The selected prompts ordered from least to most similar, so the
        most similar prompt comes last.
    """
    if not dialog_list:
        return []

    def encode(text):
        # pooler_output for a batch containing only `text`
        ids = torch.LongTensor([tokenizer.encode(text)]).cuda()
        return encoder(input_ids=ids).pooler_output

    with torch.no_grad():
        query_emb = encode(query)[0]

        if topic in emb_dict:
            example_embeddings = emb_dict[topic].cuda()
        else:
            # a single concatenation instead of re-copying a growing tensor
            # for every example
            example_embeddings = torch.cat(
                [encode(example) for example in dialog_list], dim=0)
            emb_dict[topic] = example_embeddings.cpu()

        similarity_list = example_embeddings.matmul(query_emb)
        k = min(topk, similarity_list.size(0))
        _, indices = torch.topk(similarity_list, k=k)

    # topk yields the most similar first; reverse so it ends up last
    return [prompt_list[index] for index in reversed(indices.tolist())]
| |
|
| |
|
def prompt_selection_for_knowledge_generation(
        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
    """Select in-context prompts for knowledge generation for each test sample.

    For every line of ``test_datapath`` (``topic \t context \t ...``):

    * if its topic occurs in the training data, up to 10 prompts of that
      topic are chosen with ``select_prompts_based_on_similarity``;
    * otherwise, the prompts of the 10 most similar training dialogs with
      distinct topics are chosen among the non-test-topic examples.

    Prompts are ordered from least to most similar, and each result is
    written to ``output_prompt_path`` as one json line
    ``{topic + " " + last_turn: prompts}``.

    Args:
        test_datapath: processed tsv file of the test set.
        train_datapath: processed tsv file of the training set.
        model_path: path loadable by ``torch.load`` giving the DPR encoder.
        output_prompt_path: path of the jsonl file to write.
        data_type: one of "wow_seen", "wow_unseen" or "woi".
    """
    num_prompts = 10

    print("> Selecting prompts for the knowledge generation")

    train_data_by_topic, dialog_data_by_topic, dialog_examples = \
        get_database(test_datapath, train_datapath, data_type)

    from transformers import DPRQuestionEncoderTokenizer
    print("> loading tokenizer and encoder")
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    encoder = torch.load(model_path).cuda()
    # inference only: disable dropout so the embeddings are deterministic
    encoder.eval()

    def encode(text):
        # pooler_output for a batch containing only `text`
        ids = torch.LongTensor([tokenizer.encode(text)]).cuda()
        return encoder(input_ids=ids).pooler_output

    print("> getting dialog embeddings")
    with torch.no_grad():
        embedding_list = [encode(example[1]) for example in tqdm(dialog_examples)]
    # one concatenation instead of re-copying a growing tensor per example
    dialog_embeddings = (torch.cat(embedding_list, dim=0)
                         if embedding_list else None)

    print("> reading test data from %s" % test_datapath)
    prompt_list_for_each_sample = []
    with open(test_datapath, "r") as f, torch.no_grad():
        for line in tqdm(f):
            splits = line.strip().split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]

            # Build the query exactly like get_database builds its dialog
            # examples: the "( topic ) " prefix for every type but wow_seen.
            # (The old check `!= "seen"` was always true.)
            query_sent = " ".join(turns)
            if data_type != "wow_seen":
                query_sent = "( " + topic + " ) " + query_sent

            if topic not in train_data_by_topic:
                selected_prompts = []
                if dialog_embeddings is not None:
                    similarity_list = dialog_embeddings.matmul(
                        encode(query_sent)[0])
                    # most similar first (an ascending sort used to pick
                    # the least similar examples)
                    _, indices = torch.sort(similarity_list, descending=True)
                    selected_topics = set()
                    for index in indices.tolist():
                        example_topic, _, instance = dialog_examples[index]
                        if example_topic in selected_topics:
                            continue
                        selected_topics.add(example_topic)
                        selected_prompts.append(instance)
                        if len(selected_prompts) == num_prompts:
                            break
                # least similar first, so the most similar prompt comes last
                example_list = selected_prompts[::-1]
            else:
                total_example_list = train_data_by_topic[topic]
                dialog_list = dialog_data_by_topic[topic]
                assert len(dialog_list) == len(total_example_list)
                example_list = select_prompts_based_on_similarity(
                    query_sent, dialog_list, total_example_list,
                    topic, tokenizer, encoder,
                    topk=min(len(total_example_list), num_prompts))

            key = topic + " " + turns[-1]
            prompt_list_for_each_sample.append({key: example_list})

    print("writing to %s" % output_prompt_path)
    with open(output_prompt_path, "w") as f:
        for instance in tqdm(prompt_list_for_each_sample):
            json.dump(instance, f)
            f.write("\n")
| |
|
| |
|
def prompt_selection_for_response_generation(input_path, output_path, seed):
    """Select prompt examples for response generation.

    Reads tsv samples ``topic \t context \t knowledge \t response`` and keeps
    those whose response copies much of the knowledge. Only runs of at least
    10 consecutive response tokens that occur in the knowledge count as
    copied. The copied tokens must cover 60%-90% of the response and at
    least 80% of the knowledge. Kept samples are formatted as
    "Topic: ... User says: ... We know that: ... System replies: ...",
    shuffled with ``seed``, and at most 20 are written to ``output_path``,
    one per line.

    Args:
        input_path: processed tsv file to select from.
        output_path: path of the text file to write.
        seed: numpy random seed used for the shuffle.
    """
    min_run_length = 10  # shortest copied token run that counts as overlap
    max_prompts = 20

    print("> Selecting prompts for the response generation")
    print("> set random seed")
    np.random.seed(seed)

    prompt_example_list = []
    print("> reading data from %s" % input_path)
    with open(input_path, "r") as f:
        for line in tqdm(f):
            splits = line.strip().split("\t")
            topic = splits[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")[-3:]
            if knowledge == "no_passages_used":
                continue

            # tokenize once and reuse the tokens for the output below
            knowledge_tokens = word_tokenize(knowledge)
            knowledge_vocab = set(knowledge_tokens)
            response_tokens = word_tokenize(response)
            response_len = len(response_tokens)

            num_overlap_token = 0
            run_length = 0
            for token in response_tokens:
                if token in knowledge_vocab:
                    run_length += 1
                else:
                    if run_length >= min_run_length:
                        num_overlap_token += run_length
                    run_length = 0
            if run_length >= min_run_length:
                num_overlap_token += run_length

            if num_overlap_token > response_len * 0.9 or \
                    num_overlap_token < response_len * 0.6:
                continue
            if num_overlap_token < len(knowledge_tokens) * 0.8:
                continue

            last_turn = " ".join(word_tokenize(turns[-1]))
            prompt_example = (
                "Topic: " + topic + ". "
                + "User says: " + last_turn + " "
                + "We know that: " + " ".join(knowledge_tokens) + " "
                + "System replies: " + " ".join(response_tokens))
            prompt_example_list.append(prompt_example)

    np.random.shuffle(prompt_example_list)

    print("> writing to %s" % output_path)
    with open(output_path, "w") as f:
        # slicing instead of range(20): fewer kept examples must not raise
        for example in tqdm(prompt_example_list[:max_prompts]):
            f.write(example + "\n")
| |
|
| |
|
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
    """Build response-generation inputs that use the generated knowledge.

    Every line ``topic \t context \t knowledge \t response`` of ``test_file``
    is written to ``processed_file`` with its knowledge column replaced by
    the line of ``knwl_gen_file`` at the same index. That line is stripped of
    surrounding whitespace and of any "<|endoftext|>" marker.

    Raises:
        ValueError: if ``knwl_gen_file`` has fewer lines than ``test_file``.
    """
    print("> Reading knowledge file from %s" % knwl_gen_file)
    with open(knwl_gen_file, "r") as f:
        knowledge_list = f.readlines()

    print("> Processing ...")
    with open(test_file, "r") as fr, open(processed_file, "w") as fw:
        for line_num, line in enumerate(tqdm(fr)):
            if line_num >= len(knowledge_list):
                raise ValueError(
                    "%s has only %d lines, fewer than the samples in %s"
                    % (knwl_gen_file, len(knowledge_list), test_file))
            splits = line.strip().split("\t")
            topic = splits[0]
            dialog_context = splits[1]
            response = splits[3]
            knowledge = knowledge_list[line_num].strip().replace(
                "<|endoftext|>", "")

            fw.write(topic + "\t" + dialog_context + "\t"
                     + knowledge + "\t" + response + "\n")
| |
|
| |
|
if __name__ == "__main__":

    # Dispatch to the preprocessing step selected by --func.
    args = get_args()
    if args.func == "process_wow_dataset":
        process_wow_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "process_woi_dataset":
        process_woi_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "get_knwl_gen_prompts":
        prompt_selection_for_knowledge_generation(
            args.test_file, args.train_file, args.model_file,
            args.processed_file, args.data_type)

    elif args.func == "get_resp_gen_prompts":
        prompt_selection_for_response_generation(
            args.train_file, args.processed_file, args.seed)

    elif args.func == "prepare_input":
        prepare_input_for_response_generation(
            args.test_file, args.knwl_gen_file, args.processed_file)

    else:
        # an unknown or missing --func used to exit silently with no output
        raise SystemExit(
            "Unknown --func %r; expected one of: process_wow_dataset, "
            "process_woi_dataset, get_knwl_gen_prompts, "
            "get_resp_gen_prompts, prepare_input" % args.func)
| |
|