KevinHuSh
commited on
Commit
·
cfd888e
1
Parent(s):
bbbfe3a
deal with stop reason being length problem (#109)
Browse files- api/apps/conversation_app.py +8 -5
- api/apps/user_app.py +4 -2
- deepdoc/vision/t_recognizer.py +1 -1
- rag/app/presentation.py +6 -3
- rag/llm/chat_model.py +33 -13
- rag/nlp/search.py +7 -3
api/apps/conversation_app.py
CHANGED
|
@@ -176,7 +176,7 @@ def chat(dialog, messages, **kwargs):
|
|
| 176 |
if not llm:
|
| 177 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
| 178 |
llm = llm[0]
|
| 179 |
-
|
| 180 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
| 181 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
| 182 |
|
|
@@ -184,7 +184,7 @@ def chat(dialog, messages, **kwargs):
|
|
| 184 |
## try to use sql if field mapping is good to go
|
| 185 |
if field_map:
|
| 186 |
stat_logger.info("Use SQL to retrieval.")
|
| 187 |
-
markdown_tbl, chunks = use_sql(
|
| 188 |
if markdown_tbl:
|
| 189 |
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
| 190 |
|
|
@@ -195,7 +195,9 @@ def chat(dialog, messages, **kwargs):
|
|
| 195 |
if p["key"] not in kwargs:
|
| 196 |
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
| 197 |
|
| 198 |
-
|
|
|
|
|
|
|
| 199 |
dialog.similarity_threshold,
|
| 200 |
dialog.vector_similarity_weight, top=1024, aggs=False)
|
| 201 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
|
@@ -224,13 +226,14 @@ def chat(dialog, messages, **kwargs):
|
|
| 224 |
|
| 225 |
|
| 226 |
def use_sql(question, field_map, tenant_id, chat_mdl):
|
| 227 |
-
sys_prompt = "你是一个DBA
|
| 228 |
user_promt = """
|
| 229 |
表名:{};
|
| 230 |
数据库表字段说明如下:
|
| 231 |
{}
|
| 232 |
|
| 233 |
-
|
|
|
|
| 234 |
请写出SQL,且只要SQL,不要有其他说明及文字。
|
| 235 |
""".format(
|
| 236 |
index_name(tenant_id),
|
|
|
|
| 176 |
if not llm:
|
| 177 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
| 178 |
llm = llm[0]
|
| 179 |
+
questions = [m["content"] for m in messages if m["role"] == "user"]
|
| 180 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
| 181 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
| 182 |
|
|
|
|
| 184 |
## try to use sql if field mapping is good to go
|
| 185 |
if field_map:
|
| 186 |
stat_logger.info("Use SQL to retrieval.")
|
| 187 |
+
markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
|
| 188 |
if markdown_tbl:
|
| 189 |
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
| 190 |
|
|
|
|
| 195 |
if p["key"] not in kwargs:
|
| 196 |
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
| 197 |
|
| 198 |
+
for _ in range(len(questions)//2):
|
| 199 |
+
questions.append(questions[-1])
|
| 200 |
+
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
| 201 |
dialog.similarity_threshold,
|
| 202 |
dialog.vector_similarity_weight, top=1024, aggs=False)
|
| 203 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
|
|
|
| 226 |
|
| 227 |
|
| 228 |
def use_sql(question, field_map, tenant_id, chat_mdl):
|
| 229 |
+
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
| 230 |
user_promt = """
|
| 231 |
表名:{};
|
| 232 |
数据库表字段说明如下:
|
| 233 |
{}
|
| 234 |
|
| 235 |
+
问题如下:
|
| 236 |
+
{}
|
| 237 |
请写出SQL,且只要SQL,不要有其他说明及文字。
|
| 238 |
""".format(
|
| 239 |
index_name(tenant_id),
|
api/apps/user_app.py
CHANGED
|
@@ -100,12 +100,14 @@ def github_callback():
|
|
| 100 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 101 |
user = users[0]
|
| 102 |
login_user(user)
|
|
|
|
| 103 |
except Exception as e:
|
| 104 |
rollback_user_registration(user_id)
|
| 105 |
stat_logger.exception(e)
|
| 106 |
return redirect("/?error=%s"%str(e))
|
| 107 |
-
|
| 108 |
-
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
def user_info_from_github(access_token):
|
|
|
|
| 100 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 101 |
user = users[0]
|
| 102 |
login_user(user)
|
| 103 |
+
return redirect("/?auth=%s"%user.get_id())
|
| 104 |
except Exception as e:
|
| 105 |
rollback_user_registration(user_id)
|
| 106 |
stat_logger.exception(e)
|
| 107 |
return redirect("/?error=%s"%str(e))
|
| 108 |
+
user = users[0]
|
| 109 |
+
login_user(user)
|
| 110 |
+
return redirect("/?auth=%s" % user.get_id())
|
| 111 |
|
| 112 |
|
| 113 |
def user_info_from_github(access_token):
|
deepdoc/vision/t_recognizer.py
CHANGED
|
@@ -28,7 +28,7 @@ def main(args):
|
|
| 28 |
images, outputs = init_in_out(args)
|
| 29 |
if args.mode.lower() == "layout":
|
| 30 |
labels = LayoutRecognizer.labels
|
| 31 |
-
detr = Recognizer(labels, "layout
|
| 32 |
if args.mode.lower() == "tsr":
|
| 33 |
labels = TableStructureRecognizer.labels
|
| 34 |
detr = TableStructureRecognizer()
|
|
|
|
| 28 |
images, outputs = init_in_out(args)
|
| 29 |
if args.mode.lower() == "layout":
|
| 30 |
labels = LayoutRecognizer.labels
|
| 31 |
+
detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 32 |
if args.mode.lower() == "tsr":
|
| 33 |
labels = TableStructureRecognizer.labels
|
| 34 |
detr = TableStructureRecognizer()
|
rag/app/presentation.py
CHANGED
|
@@ -73,12 +73,13 @@ class Pdf(PdfParser):
|
|
| 73 |
return res
|
| 74 |
|
| 75 |
|
| 76 |
-
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
| 77 |
"""
|
| 78 |
The supported file formats are pdf, pptx.
|
| 79 |
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
| 80 |
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
| 81 |
"""
|
|
|
|
| 82 |
doc = {
|
| 83 |
"docnm_kwd": filename,
|
| 84 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
@@ -98,8 +99,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
|
|
| 98 |
for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
|
| 99 |
d = copy.deepcopy(doc)
|
| 100 |
d["image"] = img
|
| 101 |
-
d["
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
res.append(d)
|
| 104 |
return res
|
| 105 |
|
|
|
|
| 73 |
return res
|
| 74 |
|
| 75 |
|
| 76 |
+
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
| 77 |
"""
|
| 78 |
The supported file formats are pdf, pptx.
|
| 79 |
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
| 80 |
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
| 81 |
"""
|
| 82 |
+
eng = lang.lower() == "english"
|
| 83 |
doc = {
|
| 84 |
"docnm_kwd": filename,
|
| 85 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
|
|
| 99 |
for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
|
| 100 |
d = copy.deepcopy(doc)
|
| 101 |
d["image"] = img
|
| 102 |
+
d["page_num_int"] = [pn+1]
|
| 103 |
+
d["top_int"] = [0]
|
| 104 |
+
d["position_int"].append((pn + 1, 0, img.size[0], 0, img.size[1]))
|
| 105 |
+
tokenize(d, txt, eng)
|
| 106 |
res.append(d)
|
| 107 |
return res
|
| 108 |
|
rag/llm/chat_model.py
CHANGED
|
@@ -14,9 +14,13 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
from abc import ABC
|
|
|
|
|
|
|
| 17 |
from openai import OpenAI
|
| 18 |
import openai
|
| 19 |
|
|
|
|
|
|
|
| 20 |
|
| 21 |
class Base(ABC):
|
| 22 |
def __init__(self, key, model_name):
|
|
@@ -34,13 +38,17 @@ class GptTurbo(Base):
|
|
| 34 |
def chat(self, system, history, gen_conf):
|
| 35 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 36 |
try:
|
| 37 |
-
|
| 38 |
model=self.model_name,
|
| 39 |
messages=history,
|
| 40 |
**gen_conf)
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
except openai.APIError as e:
|
| 43 |
-
return "ERROR
|
| 44 |
|
| 45 |
|
| 46 |
from dashscope import Generation
|
|
@@ -59,9 +67,16 @@ class QWenChat(Base):
|
|
| 59 |
result_format='message',
|
| 60 |
**gen_conf
|
| 61 |
)
|
|
|
|
|
|
|
| 62 |
if response.status_code == HTTPStatus.OK:
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
from zhipuai import ZhipuAI
|
|
@@ -73,11 +88,16 @@ class ZhipuChat(Base):
|
|
| 73 |
def chat(self, system, history, gen_conf):
|
| 74 |
from http import HTTPStatus
|
| 75 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 76 |
-
|
| 77 |
-
self.
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
from abc import ABC
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
|
| 19 |
from openai import OpenAI
|
| 20 |
import openai
|
| 21 |
|
| 22 |
+
from rag.nlp import is_english
|
| 23 |
+
|
| 24 |
|
| 25 |
class Base(ABC):
|
| 26 |
def __init__(self, key, model_name):
|
|
|
|
| 38 |
def chat(self, system, history, gen_conf):
|
| 39 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 40 |
try:
|
| 41 |
+
response = self.client.chat.completions.create(
|
| 42 |
model=self.model_name,
|
| 43 |
messages=history,
|
| 44 |
**gen_conf)
|
| 45 |
+
ans = response.output.choices[0]['message']['content'].strip()
|
| 46 |
+
if response.output.choices[0].get("finish_reason", "") == "length":
|
| 47 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
| 48 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
| 49 |
+
return ans, response.usage.completion_tokens
|
| 50 |
except openai.APIError as e:
|
| 51 |
+
return "**ERROR**: "+str(e), 0
|
| 52 |
|
| 53 |
|
| 54 |
from dashscope import Generation
|
|
|
|
| 67 |
result_format='message',
|
| 68 |
**gen_conf
|
| 69 |
)
|
| 70 |
+
ans = ""
|
| 71 |
+
tk_count = 0
|
| 72 |
if response.status_code == HTTPStatus.OK:
|
| 73 |
+
ans += response.output.choices[0]['message']['content']
|
| 74 |
+
tk_count += response.usage.output_tokens
|
| 75 |
+
if response.output.choices[0].get("finish_reason", "") == "length":
|
| 76 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
| 77 |
+
return ans, tk_count
|
| 78 |
+
|
| 79 |
+
return "**ERROR**: " + response.message, tk_count
|
| 80 |
|
| 81 |
|
| 82 |
from zhipuai import ZhipuAI
|
|
|
|
| 88 |
def chat(self, system, history, gen_conf):
|
| 89 |
from http import HTTPStatus
|
| 90 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 91 |
+
try:
|
| 92 |
+
response = self.client.chat.completions.create(
|
| 93 |
+
self.model_name,
|
| 94 |
+
messages=history,
|
| 95 |
+
**gen_conf
|
| 96 |
+
)
|
| 97 |
+
ans = response.output.choices[0]['message']['content'].strip()
|
| 98 |
+
if response.output.choices[0].get("finish_reason", "") == "length":
|
| 99 |
+
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
| 100 |
+
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
| 101 |
+
return ans, response.usage.completion_tokens
|
| 102 |
+
except Exception as e:
|
| 103 |
+
return "**ERROR**: " + str(e), 0
|
rag/nlp/search.py
CHANGED
|
@@ -224,12 +224,13 @@ class Dealer:
|
|
| 224 |
chunks_tks,
|
| 225 |
tkweight, vtweight)
|
| 226 |
mx = np.max(sim) * 0.99
|
| 227 |
-
if mx < 0.
|
| 228 |
continue
|
| 229 |
cites[idx[i]] = list(
|
| 230 |
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
| 231 |
|
| 232 |
res = ""
|
|
|
|
| 233 |
for i, p in enumerate(pieces):
|
| 234 |
res += p
|
| 235 |
if i not in idx:
|
|
@@ -237,7 +238,10 @@ class Dealer:
|
|
| 237 |
if i not in cites:
|
| 238 |
continue
|
| 239 |
for c in cites[i]: assert int(c) < len(chunk_v)
|
| 240 |
-
for c in cites[i]:
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
return res
|
| 243 |
|
|
@@ -318,7 +322,7 @@ class Dealer:
|
|
| 318 |
if dnm not in ranks["doc_aggs"]:
|
| 319 |
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
| 320 |
ranks["doc_aggs"][dnm]["count"] += 1
|
| 321 |
-
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
|
| 322 |
|
| 323 |
return ranks
|
| 324 |
|
|
|
|
| 224 |
chunks_tks,
|
| 225 |
tkweight, vtweight)
|
| 226 |
mx = np.max(sim) * 0.99
|
| 227 |
+
if mx < 0.66:
|
| 228 |
continue
|
| 229 |
cites[idx[i]] = list(
|
| 230 |
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
| 231 |
|
| 232 |
res = ""
|
| 233 |
+
seted = set([])
|
| 234 |
for i, p in enumerate(pieces):
|
| 235 |
res += p
|
| 236 |
if i not in idx:
|
|
|
|
| 238 |
if i not in cites:
|
| 239 |
continue
|
| 240 |
for c in cites[i]: assert int(c) < len(chunk_v)
|
| 241 |
+
for c in cites[i]:
|
| 242 |
+
if c in seted:continue
|
| 243 |
+
res += f" ##{c}$$"
|
| 244 |
+
seted.add(c)
|
| 245 |
|
| 246 |
return res
|
| 247 |
|
|
|
|
| 322 |
if dnm not in ranks["doc_aggs"]:
|
| 323 |
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
| 324 |
ranks["doc_aggs"][dnm]["count"] += 1
|
| 325 |
+
ranks["doc_aggs"] = []#[{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
|
| 326 |
|
| 327 |
return ranks
|
| 328 |
|