Kevin Hu committed on
Commit ·
1b1a5b7
1
Parent(s): faaabea
Support iframe chatbot. (#3961)
Browse files

### What problem does this PR solve?
#3909
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- agent/canvas.py +4 -1
- agent/component/base.py +16 -1
- agent/component/generate.py +8 -0
- api/apps/canvas_app.py +20 -0
- api/apps/conversation_app.py +36 -16
- api/apps/sdk/session.py +10 -0
- api/db/services/canvas_service.py +16 -28
- api/db/services/conversation_service.py +21 -18
- api/db/services/dialog_service.py +42 -77
agent/canvas.py
CHANGED
|
@@ -330,4 +330,7 @@ class Canvas(ABC):
|
|
| 330 |
q["value"] = v
|
| 331 |
|
| 332 |
def get_preset_param(self):
|
| 333 |
-
return self.components["begin"]["obj"]._param.query
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
q["value"] = v
|
| 331 |
|
| 332 |
def get_preset_param(self):
    """Return the query parameter list configured on the ``begin`` component.

    Reads ``self.components["begin"]["obj"]._param.query`` and hands back
    that same object (no copy is made).
    """
    begin_obj = self.components["begin"]["obj"]
    return begin_obj._param.query
|
| 334 |
+
|
| 335 |
+
def get_component_input_elements(self, cpnnm):
    """Return the input elements declared by the component named ``cpnnm``.

    Args:
        cpnnm: Key of the component inside ``self.components``.

    Returns:
        Whatever that component object's ``get_input_elements()`` returns.

    Raises:
        KeyError: If ``cpnnm`` is not a key of ``self.components``.
    """
    # Resolve the component the caller asked for rather than a fixed one,
    # so each component reports its own inputs.
    return self.components[cpnnm]["obj"].get_input_elements()
|
agent/component/base.py
CHANGED
|
@@ -476,7 +476,7 @@ class ComponentBase(ABC):
|
|
| 476 |
self._param.inputs.append({"component_id": q["component_id"],
|
| 477 |
"content": "\n".join(
|
| 478 |
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
| 479 |
-
elif q
|
| 480 |
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
| 481 |
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
| 482 |
if outs:
|
|
@@ -526,6 +526,21 @@ class ComponentBase(ABC):
|
|
| 526 |
|
| 527 |
return df
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
def get_stream_input(self):
|
| 530 |
reversed_cpnts = []
|
| 531 |
if len(self._canvas.path) > 1:
|
|
|
|
| 476 |
self._param.inputs.append({"component_id": q["component_id"],
|
| 477 |
"content": "\n".join(
|
| 478 |
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
| 479 |
+
elif q.get("value"):
|
| 480 |
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
| 481 |
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
| 482 |
if outs:
|
|
|
|
| 526 |
|
| 527 |
return df
|
| 528 |
|
| 529 |
+
def get_input_elements(self):
    """Collect the input elements this component expects.

    Walks ``self._param.query`` in order:

    * an entry whose ``component_id`` prefix (the text before ``@``)
      contains "begin" is expanded into the ``_param.query`` list of that
      referenced component, looked up via ``self._canvas.get_component``;
    * any other entry carrying a ``component_id`` contributes
      ``{"key": ..., "component_id": ...}``;
    * an entry without a ``component_id`` contributes ``{"key": ...}``.

    Returns:
        list: The collected elements, in query order.

    Raises:
        AssertionError: If ``self._param.query`` is empty or unset.
    """
    if not self._param.query:
        # Raised explicitly (instead of ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError("Please identify input parameters firstly.")

    eles = []
    for q in self._param.query:
        component_id = q.get("component_id")
        if not component_id:
            eles.append({"key": q["key"]})
            continue

        # ``partition`` copes with ids that have no "@" or more than one,
        # where a two-value unpack of ``split("@")`` would raise ValueError.
        cpn_id = component_id.partition("@")[0]
        if "begin" in cpn_id.lower():
            eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
        else:
            eles.append({"key": q["key"], "component_id": component_id})
    return eles
|
| 543 |
+
|
| 544 |
def get_stream_input(self):
|
| 545 |
reversed_cpnts = []
|
| 546 |
if len(self._canvas.path) > 1:
|
agent/component/generate.py
CHANGED
|
@@ -17,6 +17,7 @@ import re
|
|
| 17 |
from functools import partial
|
| 18 |
import pandas as pd
|
| 19 |
from api.db import LLMType
|
|
|
|
| 20 |
from api.db.services.dialog_service import message_fit_in
|
| 21 |
from api.db.services.llm_service import LLMBundle
|
| 22 |
from api import settings
|
|
@@ -104,9 +105,16 @@ class Generate(ComponentBase):
|
|
| 104 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
| 105 |
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
| 106 |
res = {"content": answer, "reference": reference}
|
|
|
|
| 107 |
|
| 108 |
return res
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
def _run(self, history, **kwargs):
|
| 111 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
| 112 |
prompt = self._param.prompt
|
|
|
|
| 17 |
from functools import partial
|
| 18 |
import pandas as pd
|
| 19 |
from api.db import LLMType
|
| 20 |
+
from api.db.services.conversation_service import structure_answer
|
| 21 |
from api.db.services.dialog_service import message_fit_in
|
| 22 |
from api.db.services.llm_service import LLMBundle
|
| 23 |
from api import settings
|
|
|
|
| 105 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
| 106 |
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
| 107 |
res = {"content": answer, "reference": reference}
|
| 108 |
+
res = structure_answer(None, res, "", "")
|
| 109 |
|
| 110 |
return res
|
| 111 |
|
| 112 |
+
def get_input_elements(self):
    """Return the input elements for this component.

    Yields ``self._param.parameters`` when it is truthy; otherwise falls
    back to a single default element keyed ``"input"``.
    """
    return self._param.parameters or [{"key": "input"}]
|
| 117 |
+
|
| 118 |
def _run(self, history, **kwargs):
|
| 119 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
| 120 |
prompt = self._param.prompt
|
api/apps/canvas_app.py
CHANGED
|
@@ -186,6 +186,26 @@ def reset():
|
|
| 186 |
return server_error_response(e)
|
| 187 |
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
| 190 |
@validate_request("db_type", "database", "username", "host", "port", "password")
|
| 191 |
@login_required
|
|
|
|
| 186 |
return server_error_response(e)
|
| 187 |
|
| 188 |
|
| 189 |
+
@manager.route('/input_elements', methods=['GET'])  # noqa: F821
@validate_request("id", "component_id")
@login_required
def input_elements():
    """Return the input elements of one component of a user's canvas.

    Reads ``id`` (canvas id) and ``component_id`` from the request's JSON
    body. Responds with:

    * a data-error result when ``UserCanvasService.get_by_id`` reports the
      canvas as missing;
    * an ``OPERATING_ERROR`` result when no canvas record matches both the
      current user and the given id;
    * otherwise the value of ``Canvas.get_component_input_elements`` for
      the requested component.

    Any exception raised along the way is turned into a response by
    ``server_error_response``.
    """
    # NOTE(review): this GET route reads a JSON body; confirm that clients
    # (and ``validate_request``) really supply one for GET requests.
    payload = request.json
    try:
        canvas_id = payload["id"]
        found, user_canvas = UserCanvasService.get_by_id(canvas_id)
        if not found:
            return get_data_error_result(message="canvas not found.")

        owned = UserCanvasService.query(user_id=current_user.id, id=canvas_id)
        if not owned:
            return get_json_result(
                data=False,
                message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)

        # Rebuild the canvas from its stored DSL, then ask it for the inputs.
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        elements = canvas.get_component_input_elements(payload["component_id"])
        return get_json_result(data=elements)
    except Exception as exc:
        return server_error_response(exc)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
| 210 |
@validate_request("db_type", "database", "username", "host", "port", "password")
|
| 211 |
@login_required
|
api/apps/conversation_app.py
CHANGED
|
@@ -18,7 +18,7 @@ import re
|
|
| 18 |
import traceback
|
| 19 |
from copy import deepcopy
|
| 20 |
|
| 21 |
-
from api.db.services.conversation_service import ConversationService
|
| 22 |
from api.db.services.user_service import UserTenantService
|
| 23 |
from flask import request, Response
|
| 24 |
from flask_login import login_required, current_user
|
|
@@ -90,6 +90,21 @@ def get():
|
|
| 90 |
return get_json_result(
|
| 91 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 92 |
code=settings.RetCode.OPERATING_ERROR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
conv = conv.to_dict()
|
| 94 |
return get_json_result(data=conv)
|
| 95 |
except Exception as e:
|
|
@@ -132,6 +147,7 @@ def list_convsersation():
|
|
| 132 |
dialog_id=dialog_id,
|
| 133 |
order_by=ConversationService.model.create_time,
|
| 134 |
reverse=True)
|
|
|
|
| 135 |
convs = [d.to_dict() for d in convs]
|
| 136 |
return get_json_result(data=convs)
|
| 137 |
except Exception as e:
|
|
@@ -164,24 +180,29 @@ def completion():
|
|
| 164 |
|
| 165 |
if not conv.reference:
|
| 166 |
conv.reference = []
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
| 179 |
|
|
|
|
|
|
|
|
|
|
| 180 |
def stream():
|
| 181 |
nonlocal dia, msg, req, conv
|
| 182 |
try:
|
| 183 |
for ans in chat(dia, msg, True, **req):
|
| 184 |
-
|
| 185 |
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
| 186 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
| 187 |
except Exception as e:
|
|
@@ -202,8 +223,7 @@ def completion():
|
|
| 202 |
else:
|
| 203 |
answer = None
|
| 204 |
for ans in chat(dia, msg, **req):
|
| 205 |
-
answer = ans
|
| 206 |
-
fillin_conv(ans)
|
| 207 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
| 208 |
break
|
| 209 |
return get_json_result(data=answer)
|
|
|
|
| 18 |
import traceback
|
| 19 |
from copy import deepcopy
|
| 20 |
|
| 21 |
+
from api.db.services.conversation_service import ConversationService, structure_answer
|
| 22 |
from api.db.services.user_service import UserTenantService
|
| 23 |
from flask import request, Response
|
| 24 |
from flask_login import login_required, current_user
|
|
|
|
| 90 |
return get_json_result(
|
| 91 |
data=False, message='Only owner of conversation authorized for this operation.',
|
| 92 |
code=settings.RetCode.OPERATING_ERROR)
|
| 93 |
+
|
| 94 |
+
def get_value(d, k1, k2):
|
| 95 |
+
return d.get(k1, d.get(k2))
|
| 96 |
+
|
| 97 |
+
for ref in conv.reference:
|
| 98 |
+
ref["chunks"] = [{
|
| 99 |
+
"id": get_value(ck, "chunk_id", "id"),
|
| 100 |
+
"content": get_value(ck, "content", "content_with_weight"),
|
| 101 |
+
"document_id": get_value(ck, "doc_id", "document_id"),
|
| 102 |
+
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
| 103 |
+
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
| 104 |
+
"image_id": get_value(ck, "image_id", "img_id"),
|
| 105 |
+
"positions": get_value(ck, "positions", "position_int"),
|
| 106 |
+
} for ck in ref.get("chunks", [])]
|
| 107 |
+
|
| 108 |
conv = conv.to_dict()
|
| 109 |
return get_json_result(data=conv)
|
| 110 |
except Exception as e:
|
|
|
|
| 147 |
dialog_id=dialog_id,
|
| 148 |
order_by=ConversationService.model.create_time,
|
| 149 |
reverse=True)
|
| 150 |
+
|
| 151 |
convs = [d.to_dict() for d in convs]
|
| 152 |
return get_json_result(data=convs)
|
| 153 |
except Exception as e:
|
|
|
|
| 180 |
|
| 181 |
if not conv.reference:
|
| 182 |
conv.reference = []
|
| 183 |
+
else:
|
| 184 |
+
def get_value(d, k1, k2):
|
| 185 |
+
return d.get(k1, d.get(k2))
|
| 186 |
+
|
| 187 |
+
for ref in conv.reference:
|
| 188 |
+
ref["chunks"] = [{
|
| 189 |
+
"id": get_value(ck, "chunk_id", "id"),
|
| 190 |
+
"content": get_value(ck, "content", "content_with_weight"),
|
| 191 |
+
"document_id": get_value(ck, "doc_id", "document_id"),
|
| 192 |
+
"document_name": get_value(ck, "docnm_kwd", "document_name"),
|
| 193 |
+
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
|
| 194 |
+
"image_id": get_value(ck, "image_id", "img_id"),
|
| 195 |
+
"positions": get_value(ck, "positions", "position_int"),
|
| 196 |
+
} for ck in ref.get("chunks", [])]
|
| 197 |
|
| 198 |
+
if not conv.reference:
|
| 199 |
+
conv.reference = []
|
| 200 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
| 201 |
def stream():
|
| 202 |
nonlocal dia, msg, req, conv
|
| 203 |
try:
|
| 204 |
for ans in chat(dia, msg, True, **req):
|
| 205 |
+
ans = structure_answer(conv, ans, message_id, conv.id)
|
| 206 |
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
| 207 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
| 208 |
except Exception as e:
|
|
|
|
| 223 |
else:
|
| 224 |
answer = None
|
| 225 |
for ans in chat(dia, msg, **req):
|
| 226 |
+
answer = structure_answer(conv, ans, message_id, req["conversation_id"])
|
|
|
|
| 227 |
ConversationService.update_by_id(conv.id, conv.to_dict())
|
| 228 |
break
|
| 229 |
return get_json_result(data=answer)
|
api/apps/sdk/session.py
CHANGED
|
@@ -112,6 +112,11 @@ def update(tenant_id, chat_id, session_id):
|
|
| 112 |
@token_required
|
| 113 |
def chat_completion(tenant_id, chat_id):
|
| 114 |
req = request.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if req.get("stream", True):
|
| 116 |
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
| 117 |
resp.headers.add_header("Cache-control", "no-cache")
|
|
@@ -133,6 +138,11 @@ def chat_completion(tenant_id, chat_id):
|
|
| 133 |
@token_required
|
| 134 |
def agent_completions(tenant_id, agent_id):
|
| 135 |
req = request.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
if req.get("stream", True):
|
| 137 |
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
| 138 |
resp.headers.add_header("Cache-control", "no-cache")
|
|
|
|
| 112 |
@token_required
|
| 113 |
def chat_completion(tenant_id, chat_id):
|
| 114 |
req = request.json
|
| 115 |
+
if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
|
| 116 |
+
return get_error_data_result(f"You don't own the chat {chat_id}")
|
| 117 |
+
if req.get("session_id"):
|
| 118 |
+
if not ConversationService.query(id=req["session_id"],dialog_id=chat_id):
|
| 119 |
+
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
| 120 |
if req.get("stream", True):
|
| 121 |
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
| 122 |
resp.headers.add_header("Cache-control", "no-cache")
|
|
|
|
| 138 |
@token_required
|
| 139 |
def agent_completions(tenant_id, agent_id):
|
| 140 |
req = request.json
|
| 141 |
+
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
|
| 142 |
+
return get_error_data_result(f"You don't own the agent {agent_id}")
|
| 143 |
+
if req.get("session_id"):
|
| 144 |
+
if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
|
| 145 |
+
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
| 146 |
if req.get("stream", True):
|
| 147 |
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
| 148 |
resp.headers.add_header("Cache-control", "no-cache")
|
api/db/services/canvas_service.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
|
|
|
| 17 |
from uuid import uuid4
|
| 18 |
from agent.canvas import Canvas
|
| 19 |
from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
|
|
@@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
| 58 |
if not isinstance(cvs.dsl, str):
|
| 59 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
| 60 |
canvas = Canvas(cvs.dsl, tenant_id)
|
|
|
|
|
|
|
| 61 |
|
| 62 |
if not session_id:
|
| 63 |
session_id = get_uuid()
|
|
@@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
| 84 |
return
|
| 85 |
conv = API4Conversation(**conv)
|
| 86 |
else:
|
| 87 |
-
session_id = session_id
|
| 88 |
e, conv = API4ConversationService.get_by_id(session_id)
|
| 89 |
assert e, "Session not found!"
|
| 90 |
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
conv.message
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
if m["role"] == "system":
|
| 104 |
-
continue
|
| 105 |
-
if m["role"] == "assistant" and not msg:
|
| 106 |
-
continue
|
| 107 |
-
msg.append(m)
|
| 108 |
-
if not msg[-1].get("id"):
|
| 109 |
-
msg[-1]["id"] = get_uuid()
|
| 110 |
-
message_id = msg[-1]["id"]
|
| 111 |
-
|
| 112 |
-
if not conv.reference:
|
| 113 |
-
conv.reference = []
|
| 114 |
-
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
| 115 |
-
conv.reference.append({"chunks": [], "doc_aggs": []})
|
| 116 |
|
| 117 |
final_ans = {"reference": [], "content": ""}
|
| 118 |
|
| 119 |
-
canvas.add_user_input(msg[-1]["content"])
|
| 120 |
-
|
| 121 |
if stream:
|
| 122 |
try:
|
| 123 |
for ans in canvas.run(stream=stream):
|
|
@@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|
| 141 |
conv.dsl = json.loads(str(canvas))
|
| 142 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 143 |
except Exception as e:
|
|
|
|
| 144 |
conv.dsl = json.loads(str(canvas))
|
| 145 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 146 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
| 17 |
+
import traceback
|
| 18 |
from uuid import uuid4
|
| 19 |
from agent.canvas import Canvas
|
| 20 |
from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
|
|
|
|
| 59 |
if not isinstance(cvs.dsl, str):
|
| 60 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
| 61 |
canvas = Canvas(cvs.dsl, tenant_id)
|
| 62 |
+
canvas.reset()
|
| 63 |
+
message_id = str(uuid4())
|
| 64 |
|
| 65 |
if not session_id:
|
| 66 |
session_id = get_uuid()
|
|
|
|
| 87 |
return
|
| 88 |
conv = API4Conversation(**conv)
|
| 89 |
else:
|
|
|
|
| 90 |
e, conv = API4ConversationService.get_by_id(session_id)
|
| 91 |
assert e, "Session not found!"
|
| 92 |
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
|
| 93 |
+
canvas.messages.append({"role": "user", "content": question, "id": message_id})
|
| 94 |
+
canvas.add_user_input(question)
|
| 95 |
+
if not conv.message:
|
| 96 |
+
conv.message = []
|
| 97 |
+
conv.message.append({
|
| 98 |
+
"role": "user",
|
| 99 |
+
"content": question,
|
| 100 |
+
"id": message_id
|
| 101 |
+
})
|
| 102 |
+
if not conv.reference:
|
| 103 |
+
conv.reference = []
|
| 104 |
+
conv.reference.append({"chunks": [], "doc_aggs": []})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
final_ans = {"reference": [], "content": ""}
|
| 107 |
|
|
|
|
|
|
|
| 108 |
if stream:
|
| 109 |
try:
|
| 110 |
for ans in canvas.run(stream=stream):
|
|
|
|
| 128 |
conv.dsl = json.loads(str(canvas))
|
| 129 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 130 |
except Exception as e:
|
| 131 |
+
traceback.print_exc()
|
| 132 |
conv.dsl = json.loads(str(canvas))
|
| 133 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 134 |
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
api/db/services/conversation_service.py
CHANGED
|
@@ -21,7 +21,6 @@ from api.db.services.common_service import CommonService
|
|
| 21 |
from api.db.services.dialog_service import DialogService, chat
|
| 22 |
from api.utils import get_uuid
|
| 23 |
import json
|
| 24 |
-
from copy import deepcopy
|
| 25 |
|
| 26 |
|
| 27 |
class ConversationService(CommonService):
|
|
@@ -49,30 +48,35 @@ def structure_answer(conv, ans, message_id, session_id):
|
|
| 49 |
reference = ans["reference"]
|
| 50 |
if not isinstance(reference, dict):
|
| 51 |
reference = {}
|
| 52 |
-
|
| 53 |
-
if not conv.reference:
|
| 54 |
-
conv.reference.append(temp_reference)
|
| 55 |
-
else:
|
| 56 |
-
conv.reference[-1] = temp_reference
|
| 57 |
-
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
| 58 |
|
|
|
|
|
|
|
| 59 |
chunk_list = [{
|
| 60 |
-
"id": chunk
|
| 61 |
-
"content":
|
| 62 |
-
"document_id": chunk
|
| 63 |
-
"document_name": chunk
|
| 64 |
-
"dataset_id": chunk
|
| 65 |
-
"image_id": chunk
|
| 66 |
-
"
|
| 67 |
-
"vector_similarity": chunk["vector_similarity"],
|
| 68 |
-
"term_similarity": chunk["term_similarity"],
|
| 69 |
-
"positions": chunk["positions"],
|
| 70 |
} for chunk in reference.get("chunks", [])]
|
| 71 |
|
| 72 |
reference["chunks"] = chunk_list
|
| 73 |
ans["id"] = message_id
|
| 74 |
ans["session_id"] = session_id
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
return ans
|
| 77 |
|
| 78 |
|
|
@@ -199,7 +203,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
|
| 199 |
|
| 200 |
if not conv.reference:
|
| 201 |
conv.reference = []
|
| 202 |
-
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
| 203 |
conv.reference.append({"chunks": [], "doc_aggs": []})
|
| 204 |
|
| 205 |
if stream:
|
|
|
|
| 21 |
from api.db.services.dialog_service import DialogService, chat
|
| 22 |
from api.utils import get_uuid
|
| 23 |
import json
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class ConversationService(CommonService):
|
|
|
|
| 48 |
reference = ans["reference"]
|
| 49 |
if not isinstance(reference, dict):
|
| 50 |
reference = {}
|
| 51 |
+
ans["reference"] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
def get_value(d, k1, k2):
|
| 54 |
+
return d.get(k1, d.get(k2))
|
| 55 |
chunk_list = [{
|
| 56 |
+
"id": get_value(chunk, "chunk_id", "id"),
|
| 57 |
+
"content": get_value(chunk, "content", "content_with_weight"),
|
| 58 |
+
"document_id": get_value(chunk, "doc_id", "document_id"),
|
| 59 |
+
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
| 60 |
+
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
| 61 |
+
"image_id": get_value(chunk, "image_id", "img_id"),
|
| 62 |
+
"positions": get_value(chunk, "positions", "position_int"),
|
|
|
|
|
|
|
|
|
|
| 63 |
} for chunk in reference.get("chunks", [])]
|
| 64 |
|
| 65 |
reference["chunks"] = chunk_list
|
| 66 |
ans["id"] = message_id
|
| 67 |
ans["session_id"] = session_id
|
| 68 |
|
| 69 |
+
if not conv:
|
| 70 |
+
return ans
|
| 71 |
+
|
| 72 |
+
if not conv.message:
|
| 73 |
+
conv.message = []
|
| 74 |
+
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
| 75 |
+
conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
|
| 76 |
+
else:
|
| 77 |
+
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
| 78 |
+
if conv.reference:
|
| 79 |
+
conv.reference[-1] = reference
|
| 80 |
return ans
|
| 81 |
|
| 82 |
|
|
|
|
| 203 |
|
| 204 |
if not conv.reference:
|
| 205 |
conv.reference = []
|
|
|
|
| 206 |
conv.reference.append({"chunks": [], "doc_aggs": []})
|
| 207 |
|
| 208 |
if stream:
|
api/db/services/dialog_service.py
CHANGED
|
@@ -18,6 +18,7 @@ import binascii
|
|
| 18 |
import os
|
| 19 |
import json
|
| 20 |
import re
|
|
|
|
| 21 |
from copy import deepcopy
|
| 22 |
from timeit import default_timer as timer
|
| 23 |
import datetime
|
|
@@ -108,6 +109,32 @@ def llm_id2llm_type(llm_id):
|
|
| 108 |
return llm["model_type"].strip(",")[-1]
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def chat(dialog, messages, stream=True, **kwargs):
|
| 112 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
| 113 |
st = timer()
|
|
@@ -195,32 +222,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 195 |
dialog.vector_similarity_weight,
|
| 196 |
doc_ids=attachments,
|
| 197 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
| 198 |
-
|
| 199 |
-
# Group chunks by document ID
|
| 200 |
-
doc_chunks = {}
|
| 201 |
-
for ck in kbinfos["chunks"]:
|
| 202 |
-
doc_id = ck["doc_id"]
|
| 203 |
-
if doc_id not in doc_chunks:
|
| 204 |
-
doc_chunks[doc_id] = []
|
| 205 |
-
doc_chunks[doc_id].append(ck["content_with_weight"])
|
| 206 |
-
|
| 207 |
-
# Create knowledges list with grouped chunks
|
| 208 |
-
knowledges = []
|
| 209 |
-
for doc_id, chunks in doc_chunks.items():
|
| 210 |
-
# Find the corresponding document name
|
| 211 |
-
doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
|
| 212 |
-
|
| 213 |
-
# Create a header for the document
|
| 214 |
-
doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
|
| 215 |
-
|
| 216 |
-
# Add numbered fragments
|
| 217 |
-
for i, chunk in enumerate(chunks, 1):
|
| 218 |
-
doc_knowledge += f"{i}. {chunk}\n"
|
| 219 |
-
|
| 220 |
-
knowledges.append(doc_knowledge)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
logging.debug(
|
| 225 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 226 |
retrieval_tm = timer()
|
|
@@ -603,7 +605,6 @@ def tts(tts_mdl, text):
|
|
| 603 |
|
| 604 |
def ask(question, kb_ids, tenant_id):
|
| 605 |
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
| 606 |
-
tenant_ids = [kb.tenant_id for kb in kbs]
|
| 607 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 608 |
|
| 609 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
|
@@ -612,45 +613,9 @@ def ask(question, kb_ids, tenant_id):
|
|
| 612 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
| 613 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
| 614 |
max_tokens = chat_mdl.max_length
|
| 615 |
-
|
| 616 |
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
| 617 |
-
knowledges =
|
| 618 |
-
|
| 619 |
-
used_token_count = 0
|
| 620 |
-
chunks_num = 0
|
| 621 |
-
for i, c in enumerate(knowledges):
|
| 622 |
-
used_token_count += num_tokens_from_string(c)
|
| 623 |
-
if max_tokens * 0.97 < used_token_count:
|
| 624 |
-
knowledges = knowledges[:i]
|
| 625 |
-
chunks_num = chunks_num + 1
|
| 626 |
-
break
|
| 627 |
-
|
| 628 |
-
# Group chunks by document ID
|
| 629 |
-
doc_chunks = {}
|
| 630 |
-
counter_chunks = 0
|
| 631 |
-
for ck in kbinfos["chunks"]:
|
| 632 |
-
if counter_chunks < chunks_num:
|
| 633 |
-
counter_chunks = counter_chunks + 1
|
| 634 |
-
doc_id = ck["doc_id"]
|
| 635 |
-
if doc_id not in doc_chunks:
|
| 636 |
-
doc_chunks[doc_id] = []
|
| 637 |
-
doc_chunks[doc_id].append(ck["content_with_weight"])
|
| 638 |
-
|
| 639 |
-
# Create knowledges list with grouped chunks
|
| 640 |
-
knowledges = []
|
| 641 |
-
for doc_id, chunks in doc_chunks.items():
|
| 642 |
-
# Find the corresponding document name
|
| 643 |
-
doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
|
| 644 |
-
|
| 645 |
-
# Create a header for the document
|
| 646 |
-
doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
|
| 647 |
-
|
| 648 |
-
# Add numbered fragments
|
| 649 |
-
for i, chunk in enumerate(chunks, 1):
|
| 650 |
-
doc_knowledge += f"{i}. {chunk}\n"
|
| 651 |
-
|
| 652 |
-
knowledges.append(doc_knowledge)
|
| 653 |
-
|
| 654 |
prompt = """
|
| 655 |
Role: You're a smart assistant. Your name is Miss R.
|
| 656 |
Task: Summarize the information from knowledge bases and answer user's question.
|
|
@@ -660,25 +625,25 @@ def ask(question, kb_ids, tenant_id):
|
|
| 660 |
- Answer with markdown format text.
|
| 661 |
- Answer in language of user's question.
|
| 662 |
- DO NOT make things up, especially for numbers.
|
| 663 |
-
|
| 664 |
### Information from knowledge bases
|
| 665 |
%s
|
| 666 |
-
|
| 667 |
The above is information from knowledge bases.
|
| 668 |
-
|
| 669 |
-
"""%"\n".join(knowledges)
|
| 670 |
msg = [{"role": "user", "content": question}]
|
| 671 |
|
| 672 |
def decorate_answer(answer):
|
| 673 |
nonlocal knowledges, kbinfos, prompt
|
| 674 |
answer, idx = retr.insert_citations(answer,
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
| 683 |
recall_docs = [
|
| 684 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
@@ -691,7 +656,7 @@ def ask(question, kb_ids, tenant_id):
|
|
| 691 |
del c["vector"]
|
| 692 |
|
| 693 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
| 694 |
-
answer += " Please set LLM API-Key in 'User Setting -> Model
|
| 695 |
return {"answer": answer, "reference": refs}
|
| 696 |
|
| 697 |
answer = ""
|
|
|
|
| 18 |
import os
|
| 19 |
import json
|
| 20 |
import re
|
| 21 |
+
from collections import defaultdict
|
| 22 |
from copy import deepcopy
|
| 23 |
from timeit import default_timer as timer
|
| 24 |
import datetime
|
|
|
|
| 109 |
return llm["model_type"].strip(",")[-1]
|
| 110 |
|
| 111 |
|
| 112 |
+
def kb_prompt(kbinfos, max_tokens):
    """Build per-document knowledge snippets for an LLM prompt.

    Args:
        kbinfos: Retrieval result; ``kbinfos["chunks"]`` is a list of dicts,
            each read for ``content_with_weight`` and ``docnm_kwd``.
        max_tokens: Token capacity used as the budget; chunks are consumed
            until the running token count exceeds 97% of it.

    Returns:
        list[str]: One text block per distinct document name (in order of
        first appearance). Each block starts with a ``Document: <name>``
        header followed by that document's fragments, numbered from 1.
    """
    chunks = kbinfos["chunks"]
    budget = max_tokens * 0.97

    used_token_count = 0
    chunks_num = 0
    for ck in chunks:
        used_token_count += num_tokens_from_string(ck["content_with_weight"])
        chunks_num += 1
        # NOTE(review): the chunk that crosses the budget is still counted,
        # matching the existing behaviour; confirm whether that is intended.
        if budget < used_token_count:
            break

    # Group fragments under their own document name. Keying on the literal
    # string "docnm_kwd" would merge every chunk into one bogus document.
    doc2chunks = defaultdict(list)
    for ck in chunks[:chunks_num]:
        doc2chunks[ck["docnm_kwd"]].append(ck["content_with_weight"])

    knowledges = []
    for nm, fragments in doc2chunks.items():
        parts = [f"Document: {nm} \nContains the following relevant fragments:\n"]
        parts.extend(f"{i}. {fragment}\n" for i, fragment in enumerate(fragments, 1))
        knowledges.append("".join(parts))
    return knowledges
|
| 136 |
+
|
| 137 |
+
|
| 138 |
def chat(dialog, messages, stream=True, **kwargs):
|
| 139 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
| 140 |
st = timer()
|
|
|
|
| 222 |
dialog.vector_similarity_weight,
|
| 223 |
doc_ids=attachments,
|
| 224 |
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
| 225 |
+
knowledges = kb_prompt(kbinfos, max_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
logging.debug(
|
| 227 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 228 |
retrieval_tm = timer()
|
|
|
|
| 605 |
|
| 606 |
def ask(question, kb_ids, tenant_id):
|
| 607 |
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
|
|
|
| 608 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 609 |
|
| 610 |
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
|
|
|
| 613 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
| 614 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
| 615 |
max_tokens = chat_mdl.max_length
|
| 616 |
+
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
| 617 |
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
| 618 |
+
knowledges = kb_prompt(kbinfos, max_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
prompt = """
|
| 620 |
Role: You're a smart assistant. Your name is Miss R.
|
| 621 |
Task: Summarize the information from knowledge bases and answer user's question.
|
|
|
|
| 625 |
- Answer with markdown format text.
|
| 626 |
- Answer in language of user's question.
|
| 627 |
- DO NOT make things up, especially for numbers.
|
| 628 |
+
|
| 629 |
### Information from knowledge bases
|
| 630 |
%s
|
| 631 |
+
|
| 632 |
The above is information from knowledge bases.
|
| 633 |
+
|
| 634 |
+
""" % "\n".join(knowledges)
|
| 635 |
msg = [{"role": "user", "content": question}]
|
| 636 |
|
| 637 |
def decorate_answer(answer):
|
| 638 |
nonlocal knowledges, kbinfos, prompt
|
| 639 |
answer, idx = retr.insert_citations(answer,
|
| 640 |
+
[ck["content_ltks"]
|
| 641 |
+
for ck in kbinfos["chunks"]],
|
| 642 |
+
[ck["vector"]
|
| 643 |
+
for ck in kbinfos["chunks"]],
|
| 644 |
+
embd_mdl,
|
| 645 |
+
tkweight=0.7,
|
| 646 |
+
vtweight=0.3)
|
| 647 |
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
| 648 |
recall_docs = [
|
| 649 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
|
|
| 656 |
del c["vector"]
|
| 657 |
|
| 658 |
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
| 659 |
+
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
| 660 |
return {"answer": answer, "reference": refs}
|
| 661 |
|
| 662 |
answer = ""
|