Spaces:
Runtime error
Runtime error
Commit ·
14b8b1d
1
Parent(s): e535922
refresh
Browse files- get_dataset.py +0 -68
- logger.py +0 -60
- prompt_concat.py +0 -170
- retrieve_dialog.py +0 -135
- src/retrieve_dialog.py +3 -2
- utils.py +0 -59
get_dataset.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
import sys
|
| 3 |
-
sys.path.append("../")
|
| 4 |
-
|
| 5 |
-
from collections import defaultdict
|
| 6 |
-
from .utils import is_float, load_txt
|
| 7 |
-
|
| 8 |
-
import random
|
| 9 |
-
|
| 10 |
-
random.seed(1234)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class CreateDataset:
    """Builds prompt samples for a role-play character dataset."""

    def __init__(self, max_input_len=1500):
        # Prompt template for the character dataset.
        self.prompt = load_txt("../prompt/dataset_character.txt")
        # Must stay below (seq-length) - (max-gen-length).
        self.max_input_len = max_input_len
        # Separator placed between retrieved example dialogs.
        self.example_split_flag = f"\n{'-' * 20}\n"

        self.dataset = defaultdict(list)
        self.manual_dataset = []

    @staticmethod
    def choose_examples(similar_examples,
                        max_length,
                        train_flag=False,
                        dialog=None,
                        example_split_flag=f"\n{'-' * 20}\n"):
        """Deduplicate retrieved examples and keep as many as fit in max_length.

        Args:
            similar_examples: either one string of examples joined by
                ``example_split_flag``, or a list whose items are dialogs
                (list/tuple of utterances, or a prejoined string), optionally
                wrapped as ``(score, dialog)`` pairs.
            max_length: character budget for the joined result.
            train_flag: when True, drop examples overlapping ``dialog`` and
                apply fuzzy (substring) deduplication.
            dialog: the current dialog, used to filter leakage at train time.
            example_split_flag: separator inserted between kept examples.

        Returns:
            str: selected examples joined by ``example_split_flag``.

        Raises:
            TypeError: when a dialog contains non-string utterances.
        """
        if isinstance(similar_examples, str):
            new_similar_examples = [x.strip() for x in similar_examples.split(example_split_flag)]
        else:
            # Normalize every item to one text block, deduplicating as we go.
            new_similar_examples = []
            for example in similar_examples:
                # Unwrap (score, example) pairs produced by the retriever.
                if isinstance(example, (list, tuple)) and len(example) == 2 and is_float(example[0]):
                    example = example[1]

                # BUGFIX: joining a bare string would interleave "\n" between
                # every character; keep strings as-is and join only sequences.
                if isinstance(example, str):
                    example = example.strip()
                else:
                    try:
                        example = "\n".join(example).strip()
                    except TypeError as err:
                        raise TypeError(f"example: {example}") from err

                # At train time, drop examples that leak the current dialog.
                if train_flag and dialog and (example in dialog or dialog in example):
                    continue

                if train_flag:
                    # Fuzzy dedup: drop partial (substring) overlaps too.
                    if not any(example in kept or kept in example for kept in new_similar_examples):
                        new_similar_examples.append(example)
                elif example not in new_similar_examples:
                    new_similar_examples.append(example)

        # Greedily keep examples until the character budget is exceeded.
        results = []
        total_length = 0
        for example in new_similar_examples:
            total_length += len(example) if not total_length else len(example_split_flag) + len(example)
            if total_length > max_length:
                break
            results.append(example)
        return example_split_flag.join(results).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
from logging.handlers import TimedRotatingFileHandler
|
| 3 |
-
|
| 4 |
-
import os
|
| 5 |
-
import sys
|
| 6 |
-
import logging
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class LoggerFactory:
    """Factory helpers producing configured ``logging`` loggers."""

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """Create (or fetch) a named logger that writes to stdout.

        Args:
            name (str): name of the logger
            level: level of logger

        Raises:
            ValueError if name is None
        """

        if name is None:
            raise ValueError("name for logger cannot be None")

        formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] "
                                      "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")

        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        logger_.propagate = False
        # BUGFIX: getLogger returns the same object for a given name, so
        # unconditionally adding a handler duplicated every log line on
        # repeated calls; attach the stdout handler only once.
        if not logger_.handlers:
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(level)
            ch.setFormatter(formatter)
            logger_.addHandler(ch)
        return logger_

    @staticmethod
    def create_logger_with_file(log_file_path: str = None, logger_level=logging.INFO):
        """Configure the root logger with console output plus an optional
        daily-rotating file handler (30 backups kept).

        Args:
            log_file_path: destination log file; its parent directory is
                created on demand. When falsy, only console output is added.
            logger_level: level for the root logger.

        Returns:
            The configured root logger.
        """
        logger_inner = logging.getLogger()
        logger_inner.setLevel(logger_level)
        logger_inner.propagate = True

        formatter = logging.Formatter(fmt="[%(asctime)s] [%(filename)s:%(lineno)s - %(levelname)s] %(message)s",
                                      datefmt="%Y-%m-%d %H:%M:%S")

        # TimedRotatingFileHandler
        if log_file_path:
            basedir = os.path.dirname(log_file_path)
            # BUGFIX: a bare filename has dirname "" and os.makedirs("")
            # raises; only create when a parent directory is present.
            if basedir and not os.path.isdir(basedir):
                os.makedirs(basedir, exist_ok=True)
            handler_file = TimedRotatingFileHandler(log_file_path, when="d", interval=1, backupCount=30)
            handler_file.setFormatter(formatter)
            logger_inner.addHandler(handler_file)

        # StreamHandler
        handler_console = logging.StreamHandler()
        handler_console.setFormatter(formatter)
        logger_inner.addHandler(handler_console)
        return logger_inner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt_concat.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
from copy import deepcopy
|
| 3 |
-
from .get_dataset import CreateDataset
|
| 4 |
-
from .logger import LoggerFactory
|
| 5 |
-
from .retrieve_dialog import RetrieveDialog
|
| 6 |
-
from .utils import load_json, load_txt, save_to_json
|
| 7 |
-
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
|
| 11 |
-
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class GetManualTestSamples:
    """Builds manual-evaluation samples for one role: each question is paired
    with retrieved similar dialogs and saved as JSON."""

    def __init__(
            self,
            role_name,
            role_data_path,
            save_samples_dir,
            save_samples_path=None,
            prompt_path="dataset_character.txt",
            max_seq_len=4000,
            retrieve_num=20,
    ):
        self.role_name = role_name.strip()
        self.role_data = load_json(role_data_path)
        self.role_info = self.role_data[0]["role_info"].strip()

        # Pre-fill the static template slots; ${dialog}/${simi_dialog}/
        # ${user_name} are filled per sample.
        self.prompt = load_txt(prompt_path)
        self.prompt = self.prompt.replace("${role_name}", self.role_name)
        self.prompt = self.prompt.replace("${role_info}",
                                          f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        self.retrieve_num = retrieve_num
        self.retrieve = RetrieveDialog(role_name=self.role_name,
                                       raw_dialog_list=[d["dialog"] for d in self.role_data],
                                       retrieve_num=retrieve_num)

        self.max_seq_len = max_seq_len
        if not save_samples_path:
            save_samples_path = f"{self.role_name}.json"
        self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)

    def _add_simi_dialog(self, history: list, content_length):
        """Retrieve dialogs similar to ``history`` and trim them to the
        remaining character budget (max_seq_len - content_length)."""
        retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
        simi_dialogs = deepcopy(retrieve_results)

        if simi_dialogs:
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - content_length,
                                                         train_flag=False)
        logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
        return simi_dialogs, retrieve_results

    def _build_sample(self, question, user_name, query_sep, keep_retrieve_results_flag):
        """Build one sample dict for ``question``.

        ``query_sep`` is inserted between user_name and the question when the
        question does not already carry a "speaker:" prefix.
        """
        question = question.replace('\\n', "\n")
        query = f"{user_name}{query_sep}{question}" if ":" not in question else question
        content = self.prompt.replace("${dialog}", query)
        content = content.replace("${user_name}", user_name).strip()

        history = [query]
        simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))

        sample = {
            "role_name": self.role_name,
            "role_info": self.role_info,
            "user_name": user_name,
            "dialog": history,
            "simi_dialogs": simi_dialogs,
        }
        if keep_retrieve_results_flag and retrieve_results:
            sample["retrieve_results"] = retrieve_results
        return sample

    def get_qa_samples_by_file(self,
                               questions_path,
                               user_name="user",
                               keep_retrieve_results_flag=False
                               ):
        """Build one sample per line of the questions file and save them all."""
        questions = load_txt(questions_path).splitlines()
        samples = [self._build_sample(question, user_name, ":", keep_retrieve_results_flag)
                   for question in questions]
        self._save_samples(samples)

    def get_qa_samples_by_query(self,
                                questions_query,
                                user_name="user",
                                keep_retrieve_results_flag=False
                                ):
        """Build a single sample for one query string and save it.

        NOTE(review): unlike get_qa_samples_by_file this historically joined
        with ": " (trailing space) — kept as-is for backward compatibility.
        """
        samples = [self._build_sample(questions_query, user_name, ": ", keep_retrieve_results_flag)]
        self._save_samples(samples)

    def _save_samples(self, samples):
        # Persist the built samples as pretty-printed JSON.
        save_to_json(samples, self.save_samples_path)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
class CreateTestDataset:
    """Expands previously saved samples back into full prompt texts."""

    def __init__(self,
                 role_name,
                 role_samples_path=None,
                 role_data_path=None,
                 prompt_path="dataset_character.txt",
                 max_seq_len=4000):
        self.max_seq_len = max_seq_len
        self.role_name = role_name

        self.prompt = load_txt(prompt_path)
        self.prompt = self.prompt.replace("${role_name}", role_name).strip()

        self.default_simi_dialogs = None
        # BUGFIX: previously a missing path only printed a warning and then
        # crashed with TypeError inside os.path.exists(None); fail fast.
        if not role_data_path:
            raise ValueError("need role_data_path, check please!")
        if os.path.exists(role_data_path):
            data = load_json(role_data_path)
            role_info = data[0]["role_info"]
        else:
            raise ValueError(f"{self.role_name} didn't find role_info.")
        self.role_info = role_info
        self.prompt = self.prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        # BUGFIX: always set the attribute so load_samples fails with a clear
        # error on a bad path instead of AttributeError.
        self.role_samples_path = role_samples_path
        if not role_samples_path:
            print("check role_samples_path please!")

    def load_samples(self):
        """Expand every saved sample into a final prompt string.

        Returns:
            list[dict]: one dict per sample, each holding "input_text".

        Raises:
            ValueError: when a sample has no simi_dialogs and no default.
            AssertionError: when an expanded prompt exceeds max_seq_len.
        """
        samples = load_json(self.role_samples_path)
        results = []
        for sample in samples:
            input_text = self.prompt

            simi_dialogs = sample.get("simi_dialogs", None)
            if not simi_dialogs:
                simi_dialogs = self.default_simi_dialogs
            if not simi_dialogs:
                raise ValueError("didn't find simi_dialogs.")
            # Trim the retrieved dialogs to the remaining character budget.
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - len(input_text),
                                                         train_flag=False)

            input_text = input_text.replace("${simi_dialog}", simi_dialogs)
            user_name = sample.get("user_name", "user")
            input_text = input_text.replace("${user_name}", user_name)

            dialog = "\n".join(sample["dialog"]) if isinstance(sample["dialog"], list) else sample["dialog"]
            input_text = input_text.replace("${dialog}", dialog)

            assert len(input_text) < self.max_seq_len
            results.append({
                "input_text": input_text,
            })
        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieve_dialog.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
from sentence_transformers import SentenceTransformer
|
| 3 |
-
from .utils import load_json
|
| 4 |
-
|
| 5 |
-
import faiss
|
| 6 |
-
import logging
|
| 7 |
-
import os
|
| 8 |
-
import re
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class RetrieveDialog:
    """Embeds a dialog corpus with SentenceTransformer and retrieves the
    most similar dialogs for a query through a FAISS L2 index."""

    def __init__(self,
                 role_name,
                 raw_dialog_list: list = None,
                 retrieve_num=20,
                 min_mean_role_utter_length=10):
        """
        Args:
            role_name: the played role; indexed dialogs are trimmed so they
                end with one of its utterances.
            raw_dialog_list: non-empty list of dialogs, each a list of
                "speaker:text" utterances.
            retrieve_num: default number of dialogs to retrieve.
            min_mean_role_utter_length: dialogs whose mean role-utterance
                length falls below this are filtered out (unless filtering
                would leave fewer than retrieve_num dialogs).
        """
        if torch.cuda.is_available():
            gpu_id = 0
            torch.cuda.set_device(gpu_id)

        assert raw_dialog_list

        self.role_name = role_name
        self.min_mean_role_utter_length = min_mean_role_utter_length
        self.retrieve_num = retrieve_num

        # config = load_json("config/config.json")
        # local_dir = config["bge_local_path"]
        local_dir = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')

        if not os.path.exists(local_dir):
            print("Please download bge-large-zh-v1.5 first!")
        self.emb_model = SentenceTransformer(local_dir)

        self.dialogs, self.context_index = self._get_emb_base_by_list(raw_dialog_list)

        logger.info(f"dialog db num: {len(self.dialogs)}")
        logger.info(f"RetrieveDialog init success.")

    @staticmethod
    def dialog_preprocess(dialog: list, role_name):
        """Anonymize non-role speaker names and collect role utterance lengths.

        Speaker names other than ``role_name`` are replaced by user0/user1/...
        so real names do not bias retrieval.

        Returns:
            (new_dialog, user_names, role_utter_length), or
            (None, None, None) when an utterance lacks a "speaker:" prefix.
        """
        dialog_new = []
        user_names = []
        role_utter_length = []
        for utter in dialog:  # index variable was unused — iterate directly
            try:
                user_name, utter_txt = re.split('[::]', utter, maxsplit=1)
            except ValueError:
                logging.error(f"utter:{utter} can't find user_name.")
                # BUGFIX: all callers unpack three values; the old 2-tuple
                # (None, None) raised ValueError at the call site.
                return None, None, None

            if user_name != role_name:
                if user_name not in user_names:
                    user_names.append(user_name)
                index = user_names.index(user_name)
                utter = utter.replace(user_name, f"user{index}", 1)
            else:
                role_utter_length.append(len(utter_txt))
            dialog_new.append(utter)
        return dialog_new, user_names, role_utter_length

    def _get_emb_base_by_list(self, raw_dialog_list):
        """Clean, dedupe and filter dialogs, then build the FAISS index.

        Returns:
            (kept_dialogs, faiss_index); index row i embeds the context of
            kept_dialogs[i] (every utterance except the final answer).
        """
        logger.info(f"raw dialog db num: {len(raw_dialog_list)}")
        new_raw_dialog_list = []
        context_list = []

        # Fallback pools for when length filtering removes (almost) everything.
        new_raw_dialog_list_total = []
        context_list_total = []
        for raw_dialog in raw_dialog_list:
            if not raw_dialog:
                continue

            # Trim trailing utterances until the dialog ends with the role.
            end = 0
            for x in raw_dialog[::-1]:
                if x.startswith(self.role_name):
                    break
                end += 1

            raw_dialog = raw_dialog[:len(raw_dialog) - end]
            new_dialog, user_names, role_utter_length = self.dialog_preprocess(raw_dialog, self.role_name)
            if not new_dialog or not role_utter_length:
                continue

            if raw_dialog in new_raw_dialog_list_total:
                continue  # exact duplicate

            # The last utterance is the answer; embed only the context.
            context = "\n".join(new_dialog) if len(new_dialog) < 2 else "\n".join(new_dialog[:-1])

            new_raw_dialog_list_total.append(raw_dialog)
            context_list_total.append(context)

            # Filter dialogs whose role utterances are too short on average.
            role_length_mean = sum(role_utter_length) / len(role_utter_length)
            if role_length_mean < self.min_mean_role_utter_length:
                continue
            new_raw_dialog_list.append(raw_dialog)
            context_list.append(context)

        assert len(new_raw_dialog_list) == len(context_list)
        logger.debug(f"new_raw_dialog num: {len(new_raw_dialog_list)}")

        # Too few survivors: fall back to the unfiltered pool.
        if len(new_raw_dialog_list) < self.retrieve_num:
            new_raw_dialog_list = new_raw_dialog_list_total
            context_list = context_list_total

        # Build the dialog vector store.
        context_vectors = self.emb_model.encode(context_list, normalize_embeddings=True)
        context_index = faiss.IndexFlatL2(context_vectors.shape[1])
        context_index.add(context_vectors)

        return new_raw_dialog_list, context_index

    def get_retrieve_res(self, dialog: list, retrieve_num: int):
        """Return up to retrieve_num (distance_str, dialog) pairs most
        similar to ``dialog``, nearest first."""
        logger.debug(f"dialog: {dialog}")

        # Anonymize user names the same way the index was built.
        dialog, _, _ = self.dialog_preprocess(dialog, self.role_name)
        dialog_vector = self.emb_model.encode(["\n".join(dialog)], normalize_embeddings=True)

        simi_dialog_distance, simi_dialog_index = self.context_index.search(
            dialog_vector, min(retrieve_num, len(self.dialogs)))
        simi_dialog_results = [
            (str(simi_dialog_distance[0][num]), self.dialogs[index]) for num, index in enumerate(simi_dialog_index[0])
        ]
        logger.debug(f"dialog retrieve res: {simi_dialog_results}")

        return simi_dialog_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/retrieve_dialog.py
CHANGED
|
@@ -27,8 +27,9 @@ class RetrieveDialog:
|
|
| 27 |
self.min_mean_role_utter_length = min_mean_role_utter_length
|
| 28 |
self.retrieve_num = retrieve_num
|
| 29 |
|
| 30 |
-
config = load_json("config/config.json")
|
| 31 |
-
local_dir = config["bge_local_path"]
|
|
|
|
| 32 |
|
| 33 |
if not os.path.exists(local_dir):
|
| 34 |
print("Please download bge-large-zh-v1.5 first!")
|
|
|
|
| 27 |
self.min_mean_role_utter_length = min_mean_role_utter_length
|
| 28 |
self.retrieve_num = retrieve_num
|
| 29 |
|
| 30 |
+
# config = load_json("config/config.json")
|
| 31 |
+
# local_dir = config["bge_local_path"]
|
| 32 |
+
local_dir = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
|
| 33 |
|
| 34 |
if not os.path.exists(local_dir):
|
| 35 |
print("Please download bge-large-zh-v1.5 first!")
|
utils.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
import csv
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def read_csv_to_json(file_path, role_name, role_info):
    """Read a dialog CSV and convert each data row into a sample dict.

    The first CSV row is treated as a header and skipped; column 1 of every
    remaining row holds a dialog as newline-separated utterances.
    """
    with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # skip the header row
        return [
            {
                "role_name": role_name,
                "role_info": role_info,
                "dialog": row[1].split("\n"),
            }
            for row in reader
        ]
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def save_json(json_list, output_path):
    """Serialize *json_list* to *output_path* as pretty-printed UTF-8 JSON."""
    payload = json.dumps(json_list, ensure_ascii=False, indent=4)
    with open(output_path, "w", encoding="utf-8") as out_file:
        out_file.write(payload)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def decode_csv_to_json(role_data_path, role_name, role_info, json_output_path):
    """Convert a role dialog CSV into the JSON sample format and persist it."""
    save_json(read_csv_to_json(role_data_path, role_name, role_info), json_output_path)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def load_txt(path):
    """Return the contents of *path* decoded as UTF-8 (bad bytes ignored)."""
    with open(path, "r", encoding="utf-8", errors="ignore") as file:
        return file.read()
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def load_json(path):
    """Load and return the JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def save_to_json(data, filepath, flag="w"):
    """Write *data* to *filepath* as pretty JSON, creating parent directories.

    Args:
        data: any JSON-serializable object.
        filepath: destination path; missing parent directories are created.
        flag: file-open mode (default "w"; pass "a" to append).
    """
    dirpath = os.path.dirname(filepath)
    # BUGFIX: os.makedirs("") raises for bare filenames, so only create when
    # a parent directory exists in the path; exist_ok avoids a race between
    # the existence check and creation.
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    with open(filepath, flag, encoding="utf-8") as f:
        f.write(json.dumps(data, ensure_ascii=False, indent=3))
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def is_float(my_str):
    """Return True if *my_str* can be converted to float, else False.

    Also returns False (rather than raising TypeError) for non-numeric,
    non-string inputs such as None or lists, since callers probe
    arbitrary values.
    """
    try:
        float(my_str)
        return True
    except (TypeError, ValueError):
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|