KevinHuSh
committed on
Commit
·
4c52eb9
1
Parent(s):
3772f42
refine admin initialization (#75)
Browse files- api/apps/chunk_app.py +2 -2
- api/apps/conversation_app.py +1 -3
- api/db/init_data.py +42 -4
- api/settings.py +5 -1
- deepdoc/parser/pdf_parser.py +1 -1
- deepdoc/vision/layout_recognizer.py +1 -1
- deepdoc/vision/postprocess.py +2 -3
- deepdoc/vision/recognizer.py +12 -0
- deepdoc/vision/t_recognizer.py +3 -1
- deepdoc/vision/table_structure_recognizer.py +5 -5
- rag/llm/chat_model.py +11 -8
- rag/nlp/__init__.py +2 -3
- rag/nlp/search.py +4 -2
api/apps/chunk_app.py
CHANGED
|
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
|
|
| 20 |
from elasticsearch_dsl import Q
|
| 21 |
|
| 22 |
from rag.app.qa import rmPrefix, beAdoc
|
| 23 |
-
from rag.nlp import search, huqie
|
| 24 |
from rag.utils import ELASTICSEARCH, rmSpace
|
| 25 |
from api.db import LLMType, ParserType
|
| 26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
|
|
| 28 |
from api.db.services.user_service import UserTenantService
|
| 29 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 30 |
from api.db.services.document_service import DocumentService
|
| 31 |
-
from api.settings import RetCode
|
| 32 |
from api.utils.api_utils import get_json_result
|
| 33 |
import hashlib
|
| 34 |
import re
|
|
|
|
| 20 |
from elasticsearch_dsl import Q
|
| 21 |
|
| 22 |
from rag.app.qa import rmPrefix, beAdoc
|
| 23 |
+
from rag.nlp import search, huqie
|
| 24 |
from rag.utils import ELASTICSEARCH, rmSpace
|
| 25 |
from api.db import LLMType, ParserType
|
| 26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
|
| 28 |
from api.db.services.user_service import UserTenantService
|
| 29 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 30 |
from api.db.services.document_service import DocumentService
|
| 31 |
+
from api.settings import RetCode, retrievaler
|
| 32 |
from api.utils.api_utils import get_json_result
|
| 33 |
import hashlib
|
| 34 |
import re
|
api/apps/conversation_app.py
CHANGED
|
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
|
|
| 21 |
from api.db import LLMType
|
| 22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
| 24 |
-
from api.settings import access_logger, stat_logger
|
| 25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 26 |
from api.utils import get_uuid
|
| 27 |
from api.utils.api_utils import get_json_result
|
| 28 |
from rag.app.resume import forbidden_select_fields4resume
|
| 29 |
-
from rag.llm import ChatModel
|
| 30 |
-
from rag.nlp import retrievaler
|
| 31 |
from rag.nlp.search import index_name
|
| 32 |
from rag.utils import num_tokens_from_string, encoder, rmSpace
|
| 33 |
|
|
|
|
| 21 |
from api.db import LLMType
|
| 22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
| 24 |
+
from api.settings import access_logger, stat_logger, retrievaler
|
| 25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 26 |
from api.utils import get_uuid
|
| 27 |
from api.utils.api_utils import get_json_result
|
| 28 |
from rag.app.resume import forbidden_select_fields4resume
|
|
|
|
|
|
|
| 29 |
from rag.nlp.search import index_name
|
| 30 |
from rag.utils import num_tokens_from_string, encoder, rmSpace
|
| 31 |
|
api/db/init_data.py
CHANGED
|
@@ -16,10 +16,12 @@
|
|
| 16 |
import time
|
| 17 |
import uuid
|
| 18 |
|
| 19 |
-
from api.db import LLMType
|
| 20 |
from api.db.db_models import init_database_tables as init_web_db
|
| 21 |
from api.db.services import UserService
|
| 22 |
-
from api.db.services.llm_service import LLMFactoriesService, LLMService
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def init_superuser():
|
|
@@ -32,8 +34,44 @@ def init_superuser():
|
|
| 32 |
"creator": "system",
|
| 33 |
"status": "1",
|
| 34 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
UserService.save(**user_info)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def init_llm_factory():
|
| 39 |
factory_infos = [{
|
|
@@ -171,10 +209,10 @@ def init_llm_factory():
|
|
| 171 |
|
| 172 |
def init_web_data():
|
| 173 |
start_time = time.time()
|
| 174 |
-
if not UserService.get_all().count():
|
| 175 |
-
init_superuser()
|
| 176 |
|
| 177 |
if not LLMService.get_all().count():init_llm_factory()
|
|
|
|
|
|
|
| 178 |
|
| 179 |
print("init web data success:{}".format(time.time() - start_time))
|
| 180 |
|
|
|
|
| 16 |
import time
|
| 17 |
import uuid
|
| 18 |
|
| 19 |
+
from api.db import LLMType, UserTenantRole
|
| 20 |
from api.db.db_models import init_database_tables as init_web_db
|
| 21 |
from api.db.services import UserService
|
| 22 |
+
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
| 23 |
+
from api.db.services.user_service import TenantService, UserTenantService
|
| 24 |
+
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
|
| 25 |
|
| 26 |
|
| 27 |
def init_superuser():
|
|
|
|
| 34 |
"creator": "system",
|
| 35 |
"status": "1",
|
| 36 |
}
|
| 37 |
+
tenant = {
|
| 38 |
+
"id": user_info["id"],
|
| 39 |
+
"name": user_info["nickname"] + "‘s Kingdom",
|
| 40 |
+
"llm_id": CHAT_MDL,
|
| 41 |
+
"embd_id": EMBEDDING_MDL,
|
| 42 |
+
"asr_id": ASR_MDL,
|
| 43 |
+
"parser_ids": PARSERS,
|
| 44 |
+
"img2txt_id": IMAGE2TEXT_MDL
|
| 45 |
+
}
|
| 46 |
+
usr_tenant = {
|
| 47 |
+
"tenant_id": user_info["id"],
|
| 48 |
+
"user_id": user_info["id"],
|
| 49 |
+
"invited_by": user_info["id"],
|
| 50 |
+
"role": UserTenantRole.OWNER
|
| 51 |
+
}
|
| 52 |
+
tenant_llm = []
|
| 53 |
+
for llm in LLMService.query(fid=LLM_FACTORY):
|
| 54 |
+
tenant_llm.append(
|
| 55 |
+
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
|
| 56 |
+
"api_key": API_KEY})
|
| 57 |
+
|
| 58 |
+
if not UserService.save(**user_info):
|
| 59 |
+
print("【ERROR】can't init admin.")
|
| 60 |
+
return
|
| 61 |
+
TenantService.save(**tenant)
|
| 62 |
+
UserTenantService.save(**usr_tenant)
|
| 63 |
+
TenantLLMService.insert_many(tenant_llm)
|
| 64 |
UserService.save(**user_info)
|
| 65 |
|
| 66 |
+
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
| 67 |
+
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
| 68 |
+
if msg.find("ERROR: ") == 0:
|
| 69 |
+
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
|
| 70 |
+
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
|
| 71 |
+
v,c = embd_mdl.encode(["Hello!"])
|
| 72 |
+
if c == 0:
|
| 73 |
+
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
|
| 74 |
+
|
| 75 |
|
| 76 |
def init_llm_factory():
|
| 77 |
factory_infos = [{
|
|
|
|
| 209 |
|
| 210 |
def init_web_data():
|
| 211 |
start_time = time.time()
|
|
|
|
|
|
|
| 212 |
|
| 213 |
if not LLMService.get_all().count():init_llm_factory()
|
| 214 |
+
if not UserService.get_all().count():
|
| 215 |
+
init_superuser()
|
| 216 |
|
| 217 |
print("init web data success:{}".format(time.time() - start_time))
|
| 218 |
|
api/settings.py
CHANGED
|
@@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
|
|
| 21 |
from api.utils.file_utils import get_project_base_directory
|
| 22 |
from api.utils.log_utils import LoggerFactory, getLogger
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
# Server
|
| 26 |
API_VERSION = "v1"
|
| 27 |
RAG_FLOW_SERVICE_NAME = "ragflow"
|
| 28 |
SERVER_MODULE = "rag_flow_server.py"
|
|
@@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
|
|
| 116 |
PRIVILEGE_COMMAND_WHITELIST = []
|
| 117 |
CHECK_NODES_IDENTITY = False
|
| 118 |
|
|
|
|
|
|
|
| 119 |
class CustomEnum(Enum):
|
| 120 |
@classmethod
|
| 121 |
def valid(cls, value):
|
|
|
|
| 21 |
from api.utils.file_utils import get_project_base_directory
|
| 22 |
from api.utils.log_utils import LoggerFactory, getLogger
|
| 23 |
|
| 24 |
+
from rag.nlp import search
|
| 25 |
+
from rag.utils import ELASTICSEARCH
|
| 26 |
+
|
| 27 |
|
|
|
|
| 28 |
API_VERSION = "v1"
|
| 29 |
RAG_FLOW_SERVICE_NAME = "ragflow"
|
| 30 |
SERVER_MODULE = "rag_flow_server.py"
|
|
|
|
| 118 |
PRIVILEGE_COMMAND_WHITELIST = []
|
| 119 |
CHECK_NODES_IDENTITY = False
|
| 120 |
|
| 121 |
+
retrievaler = search.Dealer(ELASTICSEARCH)
|
| 122 |
+
|
| 123 |
class CustomEnum(Enum):
|
| 124 |
@classmethod
|
| 125 |
def valid(cls, value):
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -230,7 +230,7 @@ class HuParser:
|
|
| 230 |
b["H_right"] = headers[ii]["x1"]
|
| 231 |
b["H"] = ii
|
| 232 |
|
| 233 |
-
ii = Recognizer.
|
| 234 |
if ii is not None:
|
| 235 |
b["C"] = ii
|
| 236 |
b["C_left"] = clmns[ii]["x0"]
|
|
|
|
| 230 |
b["H_right"] = headers[ii]["x1"]
|
| 231 |
b["H"] = ii
|
| 232 |
|
| 233 |
+
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
| 234 |
if ii is not None:
|
| 235 |
b["C"] = ii
|
| 236 |
b["C_left"] = clmns[ii]["x0"]
|
deepdoc/vision/layout_recognizer.py
CHANGED
|
@@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
|
|
| 37 |
super().__init__(self.labels, domain,
|
| 38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 39 |
|
| 40 |
-
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.
|
| 41 |
def __is_garbage(b):
|
| 42 |
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
| 43 |
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
|
|
|
| 37 |
super().__init__(self.labels, domain,
|
| 38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 39 |
|
| 40 |
+
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
| 41 |
def __is_garbage(b):
|
| 42 |
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
| 43 |
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
deepdoc/vision/postprocess.py
CHANGED
|
@@ -2,7 +2,6 @@ import copy
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
-
import paddle
|
| 6 |
from shapely.geometry import Polygon
|
| 7 |
import pyclipper
|
| 8 |
|
|
@@ -215,7 +214,7 @@ class DBPostProcess(object):
|
|
| 215 |
|
| 216 |
def __call__(self, outs_dict, shape_list):
|
| 217 |
pred = outs_dict['maps']
|
| 218 |
-
if isinstance(pred,
|
| 219 |
pred = pred.numpy()
|
| 220 |
pred = pred[:, 0, :, :]
|
| 221 |
segmentation = pred > self.thresh
|
|
@@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|
| 339 |
def __call__(self, preds, label=None, *args, **kwargs):
|
| 340 |
if isinstance(preds, tuple) or isinstance(preds, list):
|
| 341 |
preds = preds[-1]
|
| 342 |
-
if isinstance(preds,
|
| 343 |
preds = preds.numpy()
|
| 344 |
preds_idx = preds.argmax(axis=2)
|
| 345 |
preds_prob = preds.max(axis=2)
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
|
|
|
| 5 |
from shapely.geometry import Polygon
|
| 6 |
import pyclipper
|
| 7 |
|
|
|
|
| 214 |
|
| 215 |
def __call__(self, outs_dict, shape_list):
|
| 216 |
pred = outs_dict['maps']
|
| 217 |
+
if not isinstance(pred, np.ndarray):
|
| 218 |
pred = pred.numpy()
|
| 219 |
pred = pred[:, 0, :, :]
|
| 220 |
segmentation = pred > self.thresh
|
|
|
|
| 338 |
def __call__(self, preds, label=None, *args, **kwargs):
|
| 339 |
if isinstance(preds, tuple) or isinstance(preds, list):
|
| 340 |
preds = preds[-1]
|
| 341 |
+
if not isinstance(preds, np.ndarray):
|
| 342 |
preds = preds.numpy()
|
| 343 |
preds_idx = preds.argmax(axis=2)
|
| 344 |
preds_prob = preds.max(axis=2)
|
deepdoc/vision/recognizer.py
CHANGED
|
@@ -259,6 +259,18 @@ class Recognizer(object):
|
|
| 259 |
|
| 260 |
return max_overlaped_i
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
@staticmethod
|
| 263 |
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
| 264 |
if not boxes:
|
|
|
|
| 259 |
|
| 260 |
return max_overlaped_i
|
| 261 |
|
| 262 |
+
@staticmethod
|
| 263 |
+
def find_horizontally_tightest_fit(box, boxes):
|
| 264 |
+
if not boxes:
|
| 265 |
+
return
|
| 266 |
+
min_dis, min_i = 1000000, None
|
| 267 |
+
for i,b in enumerate(boxes):
|
| 268 |
+
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
|
| 269 |
+
if dis < min_dis:
|
| 270 |
+
min_i = i
|
| 271 |
+
min_dis = dis
|
| 272 |
+
return min_i
|
| 273 |
+
|
| 274 |
@staticmethod
|
| 275 |
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
| 276 |
if not boxes:
|
deepdoc/vision/t_recognizer.py
CHANGED
|
@@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
| 74 |
clmns = sorted([r for r in tb_cpns if re.match(
|
| 75 |
r"table column$", r["label"])], key=lambda x: x["x0"])
|
| 76 |
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
|
|
|
| 77 |
for b in boxes:
|
| 78 |
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
|
| 79 |
if ii is not None:
|
|
@@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
| 89 |
b["H_right"] = headers[ii]["x1"]
|
| 90 |
b["H"] = ii
|
| 91 |
|
| 92 |
-
ii = Recognizer.
|
| 93 |
if ii is not None:
|
| 94 |
b["C"] = ii
|
| 95 |
b["C_left"] = clmns[ii]["x0"]
|
|
@@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
| 102 |
b["H_left"] = spans[ii]["x0"]
|
| 103 |
b["H_right"] = spans[ii]["x1"]
|
| 104 |
b["SP"] = ii
|
|
|
|
| 105 |
html = """
|
| 106 |
<html>
|
| 107 |
<head>
|
|
|
|
| 74 |
clmns = sorted([r for r in tb_cpns if re.match(
|
| 75 |
r"table column$", r["label"])], key=lambda x: x["x0"])
|
| 76 |
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
| 77 |
+
|
| 78 |
for b in boxes:
|
| 79 |
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
|
| 80 |
if ii is not None:
|
|
|
|
| 90 |
b["H_right"] = headers[ii]["x1"]
|
| 91 |
b["H"] = ii
|
| 92 |
|
| 93 |
+
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
| 94 |
if ii is not None:
|
| 95 |
b["C"] = ii
|
| 96 |
b["C_left"] = clmns[ii]["x0"]
|
|
|
|
| 103 |
b["H_left"] = spans[ii]["x0"]
|
| 104 |
b["H_right"] = spans[ii]["x1"]
|
| 105 |
b["SP"] = ii
|
| 106 |
+
|
| 107 |
html = """
|
| 108 |
<html>
|
| 109 |
<head>
|
deepdoc/vision/table_structure_recognizer.py
CHANGED
|
@@ -14,7 +14,6 @@ import logging
|
|
| 14 |
import os
|
| 15 |
import re
|
| 16 |
from collections import Counter
|
| 17 |
-
from copy import deepcopy
|
| 18 |
|
| 19 |
import numpy as np
|
| 20 |
|
|
@@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
|
|
| 37 |
super().__init__(self.labels, "tsr",
|
| 38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 39 |
|
| 40 |
-
def __call__(self, images, thr=0.
|
| 41 |
tbls = super().__call__(images, thr)
|
| 42 |
res = []
|
| 43 |
# align left&right for rows, align top&bottom for columns
|
|
@@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
|
|
| 56 |
"row") > 0 or b["label"].find("header") > 0]
|
| 57 |
if not left:
|
| 58 |
continue
|
| 59 |
-
left = np.
|
| 60 |
-
right = np.
|
| 61 |
for b in lts:
|
| 62 |
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
| 63 |
if b["x0"] > left:
|
|
@@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
|
|
| 129 |
i = 0
|
| 130 |
while i < len(boxes):
|
| 131 |
if TableStructureRecognizer.is_caption(boxes[i]):
|
|
|
|
| 132 |
cap += boxes[i]["text"]
|
| 133 |
boxes.pop(i)
|
| 134 |
i -= 1
|
|
@@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
|
|
| 398 |
for i in range(clmno):
|
| 399 |
if not tbl[r][i]:
|
| 400 |
continue
|
| 401 |
-
txt = "".join([a["text"].strip() for a in tbl[r][i]])
|
| 402 |
headers[r][i] = txt
|
| 403 |
hdrset.add(txt)
|
| 404 |
if all([not t for t in headers[r]]):
|
|
|
|
| 14 |
import os
|
| 15 |
import re
|
| 16 |
from collections import Counter
|
|
|
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
|
|
|
|
| 36 |
super().__init__(self.labels, "tsr",
|
| 37 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 38 |
|
| 39 |
+
def __call__(self, images, thr=0.2):
|
| 40 |
tbls = super().__call__(images, thr)
|
| 41 |
res = []
|
| 42 |
# align left&right for rows, align top&bottom for columns
|
|
|
|
| 55 |
"row") > 0 or b["label"].find("header") > 0]
|
| 56 |
if not left:
|
| 57 |
continue
|
| 58 |
+
left = np.mean(left) if len(left) > 4 else np.min(left)
|
| 59 |
+
right = np.mean(right) if len(right) > 4 else np.max(right)
|
| 60 |
for b in lts:
|
| 61 |
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
| 62 |
if b["x0"] > left:
|
|
|
|
| 128 |
i = 0
|
| 129 |
while i < len(boxes):
|
| 130 |
if TableStructureRecognizer.is_caption(boxes[i]):
|
| 131 |
+
if is_english: cap + " "
|
| 132 |
cap += boxes[i]["text"]
|
| 133 |
boxes.pop(i)
|
| 134 |
i -= 1
|
|
|
|
| 398 |
for i in range(clmno):
|
| 399 |
if not tbl[r][i]:
|
| 400 |
continue
|
| 401 |
+
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
|
| 402 |
headers[r][i] = txt
|
| 403 |
hdrset.add(txt)
|
| 404 |
if all([not t for t in headers[r]]):
|
rag/llm/chat_model.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
#
|
| 16 |
from abc import ABC
|
| 17 |
from openai import OpenAI
|
| 18 |
-
import
|
| 19 |
|
| 20 |
|
| 21 |
class Base(ABC):
|
|
@@ -33,11 +33,14 @@ class GptTurbo(Base):
|
|
| 33 |
|
| 34 |
def chat(self, system, history, gen_conf):
|
| 35 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
from dashscope import Generation
|
|
@@ -58,7 +61,7 @@ class QWenChat(Base):
|
|
| 58 |
)
|
| 59 |
if response.status_code == HTTPStatus.OK:
|
| 60 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
| 61 |
-
return response.message, 0
|
| 62 |
|
| 63 |
|
| 64 |
from zhipuai import ZhipuAI
|
|
@@ -77,4 +80,4 @@ class ZhipuChat(Base):
|
|
| 77 |
)
|
| 78 |
if response.status_code == HTTPStatus.OK:
|
| 79 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
| 80 |
-
return response.message, 0
|
|
|
|
| 15 |
#
|
| 16 |
from abc import ABC
|
| 17 |
from openai import OpenAI
|
| 18 |
+
import openai
|
| 19 |
|
| 20 |
|
| 21 |
class Base(ABC):
|
|
|
|
| 33 |
|
| 34 |
def chat(self, system, history, gen_conf):
|
| 35 |
if system: history.insert(0, {"role": "system", "content": system})
|
| 36 |
+
try:
|
| 37 |
+
res = self.client.chat.completions.create(
|
| 38 |
+
model=self.model_name,
|
| 39 |
+
messages=history,
|
| 40 |
+
**gen_conf)
|
| 41 |
+
return res.choices[0].message.content.strip(), res.usage.completion_tokens
|
| 42 |
+
except openai.APIError as e:
|
| 43 |
+
return "ERROR: "+str(e), 0
|
| 44 |
|
| 45 |
|
| 46 |
from dashscope import Generation
|
|
|
|
| 61 |
)
|
| 62 |
if response.status_code == HTTPStatus.OK:
|
| 63 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
| 64 |
+
return "ERROR: " + response.message, 0
|
| 65 |
|
| 66 |
|
| 67 |
from zhipuai import ZhipuAI
|
|
|
|
| 80 |
)
|
| 81 |
if response.status_code == HTTPStatus.OK:
|
| 82 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
| 83 |
+
return "ERROR: " + response.message, 0
|
rag/nlp/__init__.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
-
from . import search
|
| 2 |
-
from rag.utils import ELASTICSEARCH
|
| 3 |
|
| 4 |
-
retrievaler = search.Dealer(ELASTICSEARCH)
|
| 5 |
|
| 6 |
from nltk.stem import PorterStemmer
|
| 7 |
stemmer = PorterStemmer()
|
|
@@ -39,10 +36,12 @@ BULLET_PATTERN = [[
|
|
| 39 |
]
|
| 40 |
]
|
| 41 |
|
|
|
|
| 42 |
def random_choices(arr, k):
|
| 43 |
k = min(len(arr), k)
|
| 44 |
return random.choices(arr, k=k)
|
| 45 |
|
|
|
|
| 46 |
def bullets_category(sections):
|
| 47 |
global BULLET_PATTERN
|
| 48 |
hits = [0] * len(BULLET_PATTERN)
|
|
|
|
|
|
|
|
|
|
| 1 |
|
|
|
|
| 2 |
|
| 3 |
from nltk.stem import PorterStemmer
|
| 4 |
stemmer = PorterStemmer()
|
|
|
|
| 36 |
]
|
| 37 |
]
|
| 38 |
|
| 39 |
+
|
| 40 |
def random_choices(arr, k):
|
| 41 |
k = min(len(arr), k)
|
| 42 |
return random.choices(arr, k=k)
|
| 43 |
|
| 44 |
+
|
| 45 |
def bullets_category(sections):
|
| 46 |
global BULLET_PATTERN
|
| 47 |
hits = [0] * len(BULLET_PATTERN)
|
rag/nlp/search.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
-
from elasticsearch_dsl import Q, Search
|
| 5 |
from typing import List, Optional, Dict, Union
|
| 6 |
from dataclasses import dataclass
|
| 7 |
|
|
@@ -183,6 +183,7 @@ class Dealer:
|
|
| 183 |
|
| 184 |
def insert_citations(self, answer, chunks, chunk_v,
|
| 185 |
embd_mdl, tkweight=0.3, vtweight=0.7):
|
|
|
|
| 186 |
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
|
| 187 |
for i in range(1, len(pieces)):
|
| 188 |
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
|
|
@@ -216,7 +217,7 @@ class Dealer:
|
|
| 216 |
if mx < 0.55:
|
| 217 |
continue
|
| 218 |
cites[idx[i]] = list(
|
| 219 |
-
set([str(
|
| 220 |
|
| 221 |
res = ""
|
| 222 |
for i, p in enumerate(pieces):
|
|
@@ -225,6 +226,7 @@ class Dealer:
|
|
| 225 |
continue
|
| 226 |
if i not in cites:
|
| 227 |
continue
|
|
|
|
| 228 |
res += "##%s$$" % "$".join(cites[i])
|
| 229 |
|
| 230 |
return res
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
+
from elasticsearch_dsl import Q, Search
|
| 5 |
from typing import List, Optional, Dict, Union
|
| 6 |
from dataclasses import dataclass
|
| 7 |
|
|
|
|
| 183 |
|
| 184 |
def insert_citations(self, answer, chunks, chunk_v,
|
| 185 |
embd_mdl, tkweight=0.3, vtweight=0.7):
|
| 186 |
+
assert len(chunks) == len(chunk_v)
|
| 187 |
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
|
| 188 |
for i in range(1, len(pieces)):
|
| 189 |
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
|
|
|
|
| 217 |
if mx < 0.55:
|
| 218 |
continue
|
| 219 |
cites[idx[i]] = list(
|
| 220 |
+
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
| 221 |
|
| 222 |
res = ""
|
| 223 |
for i, p in enumerate(pieces):
|
|
|
|
| 226 |
continue
|
| 227 |
if i not in cites:
|
| 228 |
continue
|
| 229 |
+
assert int(cites[i]) < len(chunk_v)
|
| 230 |
res += "##%s$$" % "$".join(cites[i])
|
| 231 |
|
| 232 |
return res
|