maxiaolong03
commited on
Commit
·
98d3121
1
Parent(s):
3a5faf4
add files
Browse files- app.py +247 -164
- bot_requests.py +40 -38
app.py
CHANGED
|
@@ -12,11 +12,14 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
-
"""
|
|
|
|
|
|
|
| 16 |
|
| 17 |
import argparse
|
| 18 |
import base64
|
| 19 |
from collections import namedtuple
|
|
|
|
| 20 |
from functools import partial
|
| 21 |
import hashlib
|
| 22 |
import json
|
|
@@ -25,12 +28,12 @@ import faiss
|
|
| 25 |
import os
|
| 26 |
from argparse import ArgumentParser
|
| 27 |
import textwrap
|
|
|
|
| 28 |
|
| 29 |
import gradio as gr
|
| 30 |
import numpy as np
|
| 31 |
|
| 32 |
from bot_requests import BotClient
|
| 33 |
-
# from faiss_text_database import FaissTextDatabase
|
| 34 |
|
| 35 |
os.environ["NO_PROXY"] = "localhost,127.0.0.1" # Disable proxy
|
| 36 |
|
|
@@ -44,89 +47,105 @@ RELEVANT_PASSAGE_DEFAULT = textwrap.dedent("""\
|
|
| 44 |
)
|
| 45 |
|
| 46 |
QUERY_REWRITE_PROMPT = textwrap.dedent("""\
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
1.
|
| 55 |
-
2.
|
| 56 |
-
3.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
```
|
| 64 |
{{
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
```"""
|
| 69 |
)
|
| 70 |
-
|
| 71 |
ANSWER_PROMPT = textwrap.dedent(
|
| 72 |
"""\
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
-
|
|
|
|
| 83 |
)
|
| 84 |
QUERY_DEFAULT = "1675 年时,英格兰有多少家咖啡馆?"
|
| 85 |
|
| 86 |
|
| 87 |
def get_args() -> argparse.Namespace:
|
| 88 |
"""
|
| 89 |
-
Parse and return command line arguments for the ERNIE models
|
| 90 |
-
Configures server settings, model
|
| 91 |
|
| 92 |
Returns:
|
| 93 |
-
argparse.Namespace: Parsed command line arguments containing
|
| 94 |
-
- server_port: Demo server port (default: 8333)
|
| 95 |
-
- server_name: Demo server host (default: "0.0.0.0")
|
| 96 |
-
- model_urls: Endpoints for ERNIE and Qianfan models
|
| 97 |
-
- document_processing: Chunk size, FAISS index and text DB paths
|
| 98 |
"""
|
| 99 |
parser = ArgumentParser(description="ERNIE models web chat demo.")
|
| 100 |
|
| 101 |
parser.add_argument(
|
| 102 |
-
"--server-port", type=int, default=
|
| 103 |
)
|
| 104 |
parser.add_argument(
|
| 105 |
"--server-name", type=str, default="0.0.0.0", help="Demo server name."
|
| 106 |
)
|
| 107 |
parser.add_argument(
|
| 108 |
-
"--max_char", type=int, default=
|
| 109 |
)
|
| 110 |
parser.add_argument(
|
| 111 |
"--max_retry_num", type=int, default=3, help="Maximum retry number for request."
|
| 112 |
)
|
| 113 |
parser.add_argument(
|
| 114 |
-
"--
|
| 115 |
type=str,
|
| 116 |
-
default="https://qianfan.baidubce.com/v2",
|
| 117 |
-
help="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
)
|
| 119 |
parser.add_argument(
|
| 120 |
-
"--
|
| 121 |
type=str,
|
| 122 |
default="https://qianfan.baidubce.com/v2",
|
| 123 |
-
help="
|
| 124 |
)
|
| 125 |
parser.add_argument(
|
| 126 |
"--qianfan_api_key",
|
| 127 |
type=str,
|
| 128 |
default=os.environ.get("API_KEY"),
|
| 129 |
-
help="Qianfan API key."
|
| 130 |
)
|
| 131 |
parser.add_argument(
|
| 132 |
"--embedding_model",
|
|
@@ -134,12 +153,24 @@ def get_args() -> argparse.Namespace:
|
|
| 134 |
default="embedding-v1",
|
| 135 |
help="Embedding model name."
|
| 136 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
parser.add_argument(
|
| 138 |
"--chunk_size",
|
| 139 |
type=int,
|
| 140 |
default=512,
|
| 141 |
help="Chunk size for splitting long documents."
|
| 142 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
parser.add_argument(
|
| 144 |
"--faiss_index_path",
|
| 145 |
type=str,
|
|
@@ -154,15 +185,24 @@ def get_args() -> argparse.Namespace:
|
|
| 154 |
)
|
| 155 |
|
| 156 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
return args
|
| 158 |
|
| 159 |
|
| 160 |
class FaissTextDatabase:
|
| 161 |
"""
|
| 162 |
-
A vector database for text retrieval using FAISS
|
| 163 |
Provides efficient similarity search and document management capabilities.
|
| 164 |
"""
|
| 165 |
-
def __init__(self, args, bot_client: BotClient
|
| 166 |
"""
|
| 167 |
Initialize the FaissTextDatabase.
|
| 168 |
|
|
@@ -174,9 +214,11 @@ class FaissTextDatabase:
|
|
| 174 |
self.logger = logging.getLogger(__name__)
|
| 175 |
|
| 176 |
self.bot_client = bot_client
|
|
|
|
|
|
|
|
|
|
| 177 |
self.faiss_index_path = getattr(args, "faiss_index_path", "data/faiss_index")
|
| 178 |
self.text_db_path = getattr(args, "text_db_path", "data/text_db.jsonl")
|
| 179 |
-
self.embedding_dim = embedding_dim
|
| 180 |
|
| 181 |
# If faiss_index_path exists, load it and text_db_path
|
| 182 |
if os.path.exists(self.faiss_index_path) and os.path.exists(self.text_db_path):
|
|
@@ -216,7 +258,8 @@ class FaissTextDatabase:
|
|
| 216 |
file_md5 = self.calculate_md5(file_path)
|
| 217 |
return file_md5 in self.text_db["file_md5s"]
|
| 218 |
|
| 219 |
-
def add_embeddings(self, file_path: str, segments: list[str], progress_bar: gr.Progress=None
|
|
|
|
| 220 |
"""
|
| 221 |
Stores document embeddings in FAISS database after checking for duplicates.
|
| 222 |
Generates embeddings for each text segment, updates the FAISS index and metadata database,
|
|
@@ -241,8 +284,9 @@ class FaissTextDatabase:
|
|
| 241 |
# Generate embeddings
|
| 242 |
vectors = []
|
| 243 |
file_name = os.path.basename(file_path)
|
|
|
|
| 244 |
for i, segment in enumerate(segments):
|
| 245 |
-
vectors.append(self.bot_client.embed_fn(segment))
|
| 246 |
if progress_bar is not None:
|
| 247 |
progress_bar((i + 1) / len(segments), desc=file_name + " Processing...")
|
| 248 |
vectors = np.array(vectors)
|
|
@@ -252,43 +296,87 @@ class FaissTextDatabase:
|
|
| 252 |
for i, text in enumerate(segments):
|
| 253 |
self.text_db["chunks"].append({
|
| 254 |
"file_md5": file_md5,
|
|
|
|
|
|
|
| 255 |
"text": text,
|
| 256 |
"vector_id": start_id + i
|
| 257 |
})
|
| 258 |
|
| 259 |
self.text_db["file_md5s"].append(file_md5)
|
| 260 |
-
|
|
|
|
| 261 |
return True
|
| 262 |
|
| 263 |
-
def search_with_context(self,
|
| 264 |
"""
|
| 265 |
-
Finds the most relevant text
|
| 266 |
-
Uses FAISS to find the closest matching
|
| 267 |
from the same source document to provide better context understanding.
|
| 268 |
-
|
| 269 |
Args:
|
| 270 |
-
|
| 271 |
-
context_size: the number of surrounding chunks to include
|
| 272 |
|
| 273 |
Returns:
|
| 274 |
-
str: the
|
| 275 |
-
"""
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
self.logger.info("
|
| 283 |
-
self.logger.info("Target Chunk: {}".format(self.text_db["chunks"][target_idx]["text"]))
|
| 284 |
|
| 285 |
-
#
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
result = ""
|
| 289 |
-
for
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
| 292 |
|
| 293 |
return result
|
| 294 |
|
|
@@ -305,13 +393,35 @@ class GradioEvents(object):
|
|
| 305 |
Manages event handling and UI interactions for Gradio applications.
|
| 306 |
Provides methods to process user inputs, trigger callbacks, and update interface components.
|
| 307 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
@staticmethod
|
| 309 |
def chat_stream(
|
| 310 |
query: str,
|
| 311 |
task_history: list,
|
| 312 |
model: str,
|
| 313 |
-
bot_client: BotClient,
|
| 314 |
faiss_db: FaissTextDatabase,
|
|
|
|
| 315 |
) -> dict:
|
| 316 |
"""
|
| 317 |
Streams chatbot responses by processing queries with context from history and FAISS database.
|
|
@@ -328,23 +438,29 @@ class GradioEvents(object):
|
|
| 328 |
Yields:
|
| 329 |
dict: A dictionary containing the event type and its corresponding content.
|
| 330 |
"""
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
)
|
| 337 |
-
yield {"type": "relevant_passage", "content": relevant_passage}
|
| 338 |
-
input = ANSWER_PROMPT.format(query=query, relevant_passage=relevant_passage)
|
| 339 |
-
else:
|
| 340 |
-
input = query
|
| 341 |
|
| 342 |
-
conversation
|
| 343 |
-
for query_h, response_h in task_history:
|
| 344 |
-
conversation.append({"role": "user", "content": query_h})
|
| 345 |
-
conversation.append({"role": "assistant", "content": response_h})
|
| 346 |
-
conversation.append({"role": "user", "content": input})
|
| 347 |
-
|
| 348 |
try:
|
| 349 |
req_data = {"messages": conversation}
|
| 350 |
for chunk in bot_client.process_stream(model, req_data):
|
|
@@ -353,10 +469,7 @@ class GradioEvents(object):
|
|
| 353 |
|
| 354 |
message = chunk.get("choices", [{}])[0].get("delta", {})
|
| 355 |
content = message.get("content", "")
|
| 356 |
-
reasoning_content = message.get("reasoning_content", "")
|
| 357 |
|
| 358 |
-
if reasoning_content:
|
| 359 |
-
yield {"type": "thinking", "content": reasoning_content}
|
| 360 |
if content:
|
| 361 |
yield {"type": "answer", "content": content}
|
| 362 |
|
|
@@ -369,8 +482,8 @@ class GradioEvents(object):
|
|
| 369 |
chatbot: list,
|
| 370 |
task_history: list,
|
| 371 |
model: str,
|
|
|
|
| 372 |
bot_client: BotClient,
|
| 373 |
-
faiss_db: FaissTextDatabase
|
| 374 |
) -> tuple:
|
| 375 |
"""
|
| 376 |
Generates streaming responses by combining model predictions with knowledge retrieval.
|
|
@@ -400,12 +513,11 @@ class GradioEvents(object):
|
|
| 400 |
query,
|
| 401 |
task_history,
|
| 402 |
model,
|
| 403 |
-
bot_client,
|
| 404 |
faiss_db,
|
|
|
|
| 405 |
)
|
| 406 |
-
|
| 407 |
response = ""
|
| 408 |
-
has_thinking = False
|
| 409 |
current_relevant_passage = None
|
| 410 |
for new_text in new_texts:
|
| 411 |
if not isinstance(new_text, dict):
|
|
@@ -419,27 +531,15 @@ class GradioEvents(object):
|
|
| 419 |
current_relevant_passage = new_text["content"]
|
| 420 |
yield chatbot, current_relevant_passage
|
| 421 |
continue
|
| 422 |
-
elif new_text.get("type") == "thinking":
|
| 423 |
-
has_thinking = True
|
| 424 |
-
reasoning_content += new_text["content"]
|
| 425 |
elif new_text.get("type") == "answer":
|
| 426 |
response += new_text["content"]
|
| 427 |
|
| 428 |
-
# Remove previous
|
| 429 |
if chatbot[-1].get("role") == "assistant":
|
| 430 |
chatbot.pop(-1)
|
| 431 |
|
| 432 |
-
content = ""
|
| 433 |
-
if has_thinking:
|
| 434 |
-
content = "**思考过程:**<br>{}<br>".format(reasoning_content)
|
| 435 |
if response:
|
| 436 |
-
|
| 437 |
-
content += "<br><br>**最终回答:**<br>{}".format(response)
|
| 438 |
-
else:
|
| 439 |
-
content = response
|
| 440 |
-
|
| 441 |
-
if content:
|
| 442 |
-
chatbot.append({"role": "assistant", "content": content})
|
| 443 |
yield chatbot, current_relevant_passage
|
| 444 |
|
| 445 |
logging.info("History: {}".format(task_history))
|
|
@@ -451,8 +551,8 @@ class GradioEvents(object):
|
|
| 451 |
chatbot: list,
|
| 452 |
task_history: list,
|
| 453 |
model: str,
|
|
|
|
| 454 |
bot_client: BotClient,
|
| 455 |
-
faiss_db: FaissTextDatabase
|
| 456 |
) -> tuple:
|
| 457 |
"""
|
| 458 |
Regenerate the chatbot's response based on the latest user query
|
|
@@ -481,8 +581,8 @@ class GradioEvents(object):
|
|
| 481 |
chatbot,
|
| 482 |
task_history,
|
| 483 |
model,
|
|
|
|
| 484 |
bot_client,
|
| 485 |
-
faiss_db
|
| 486 |
):
|
| 487 |
yield chunk, relevant_passage
|
| 488 |
|
|
@@ -548,44 +648,20 @@ class GradioEvents(object):
|
|
| 548 |
return url
|
| 549 |
|
| 550 |
@staticmethod
|
| 551 |
-
def
|
| 552 |
-
sub_query_list: list,
|
| 553 |
-
faiss_db: FaissTextDatabase
|
| 554 |
-
) -> str:
|
| 555 |
-
"""
|
| 556 |
-
Retrieve the relevant passage from the database based on the query.
|
| 557 |
-
|
| 558 |
-
Args:
|
| 559 |
-
sub_query_list (list): List of sub-queries.
|
| 560 |
-
faiss_db (FaissTextDatabase): The FAISS database instance.
|
| 561 |
-
|
| 562 |
-
Returns:
|
| 563 |
-
str: The relevant passage.
|
| 564 |
-
"""
|
| 565 |
-
relevant_passages = ""
|
| 566 |
-
for idx, query_item in enumerate(sub_query_list):
|
| 567 |
-
relevant_passage = faiss_db.search_with_context(query_item)
|
| 568 |
-
relevant_passages += "\n段落{idx}:\n{relevant_passage}".format(idx=idx + 1, relevant_passage=relevant_passage)
|
| 569 |
-
|
| 570 |
-
return relevant_passages
|
| 571 |
-
|
| 572 |
-
@staticmethod
|
| 573 |
-
def get_sub_query(query: str, model_name: str, bot_client: BotClient) -> dict:
|
| 574 |
"""
|
| 575 |
Enhances user queries by generating alternative phrasings using language models.
|
| 576 |
Creates semantically similar variations of the original query to improve retrieval accuracy.
|
| 577 |
Returns structured dictionary containing both original and rephrased queries.
|
| 578 |
|
| 579 |
Args:
|
| 580 |
-
|
| 581 |
model_name (str): The name of the model to use for rephrasing.
|
| 582 |
bot_client (BotClient): The bot client instance.
|
| 583 |
|
| 584 |
Returns:
|
| 585 |
dict: The rephrased query.
|
| 586 |
"""
|
| 587 |
-
query = QUERY_REWRITE_PROMPT.format(query=query)
|
| 588 |
-
conversation = [{"role": "user", "content": query}]
|
| 589 |
req_data = {"messages": conversation}
|
| 590 |
try:
|
| 591 |
response = bot_client.process(model_name, req_data)
|
|
@@ -600,7 +676,8 @@ class GradioEvents(object):
|
|
| 600 |
search_info_res["sub_query_list"] = unique_list
|
| 601 |
return search_info_res
|
| 602 |
except Exception:
|
| 603 |
-
|
|
|
|
| 604 |
|
| 605 |
@staticmethod
|
| 606 |
def split_oversized_line(line: str, chunk_size: int) -> tuple:
|
|
@@ -615,7 +692,7 @@ class GradioEvents(object):
|
|
| 615 |
Returns:
|
| 616 |
tuple: Two strings, the first part of the original line and the rest of the line.
|
| 617 |
"""
|
| 618 |
-
PUNCTUATIONS =
|
| 619 |
|
| 620 |
if len(line) <= chunk_size:
|
| 621 |
return line, ""
|
|
@@ -636,28 +713,33 @@ class GradioEvents(object):
|
|
| 636 |
return line[:split_pos], line[split_pos:]
|
| 637 |
|
| 638 |
@staticmethod
|
| 639 |
-
def split_text_into_chunks(
|
| 640 |
"""
|
| 641 |
-
Split text into chunks of a specified size while respecting natural language boundaries
|
| 642 |
and avoiding mid-word splits whenever possible.
|
| 643 |
|
| 644 |
Args:
|
| 645 |
-
|
| 646 |
chunk_size (int): The maximum length of each chunk.
|
| 647 |
|
| 648 |
Returns:
|
| 649 |
list: A list of strings, where each element represents a chunk of the original text.
|
| 650 |
"""
|
| 651 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
chunks = []
|
| 653 |
current_chunk = []
|
| 654 |
current_length = 0
|
| 655 |
|
| 656 |
for line in lines:
|
| 657 |
-
|
| 658 |
# If adding this line would exceed chunk size (and we have content)
|
| 659 |
if current_length + len(line) > chunk_size and current_chunk:
|
| 660 |
-
chunks.append("
|
| 661 |
current_chunk = []
|
| 662 |
current_length = 0
|
| 663 |
|
|
@@ -672,7 +754,7 @@ class GradioEvents(object):
|
|
| 672 |
current_length += len(line) + 1
|
| 673 |
|
| 674 |
if current_chunk:
|
| 675 |
-
chunks.append("
|
| 676 |
return chunks
|
| 677 |
|
| 678 |
@staticmethod
|
|
@@ -706,7 +788,8 @@ class GradioEvents(object):
|
|
| 706 |
yield gr.update(visible=False)
|
| 707 |
|
| 708 |
@staticmethod
|
| 709 |
-
def save_file_to_db(file_url: str, chunk_size: int, faiss_db: FaissTextDatabase,
|
|
|
|
| 710 |
"""
|
| 711 |
Processes and indexes document content into FAISS database with semantic-aware chunking.
|
| 712 |
Handles file validation, text segmentation, embedding generation and storage operations.
|
|
@@ -720,14 +803,16 @@ class GradioEvents(object):
|
|
| 720 |
Returns:
|
| 721 |
bool: True if the file was saved successfully, otherwise False.
|
| 722 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
file_name = os.path.basename(file_url)
|
| 724 |
if not faiss_db.is_file_processed(file_url):
|
| 725 |
logging.info("{} not processed yet, processing now...".format(file_url))
|
| 726 |
try:
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
segments = GradioEvents.split_text_into_chunks(text, chunk_size)
|
| 730 |
-
faiss_db.add_embeddings(file_url, segments, progress_bar)
|
| 731 |
|
| 732 |
logging.info("{} processed successfully.".format(file_url))
|
| 733 |
return True
|
|
@@ -740,7 +825,7 @@ class GradioEvents(object):
|
|
| 740 |
return False
|
| 741 |
|
| 742 |
|
| 743 |
-
def launch_demo(args: argparse.Namespace, bot_client: BotClient,
|
| 744 |
"""
|
| 745 |
Launch demo program
|
| 746 |
|
|
@@ -770,7 +855,8 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db: Faiss
|
|
| 770 |
}
|
| 771 |
"""
|
| 772 |
with gr.Blocks(css=css) as demo:
|
| 773 |
-
model_name = gr.State(
|
|
|
|
| 774 |
|
| 775 |
logo_url = GradioEvents.get_image_url("assets/logo.png")
|
| 776 |
gr.Markdown("""\
|
|
@@ -816,35 +902,32 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db: Faiss
|
|
| 816 |
|
| 817 |
predict_with_clients = partial(
|
| 818 |
GradioEvents.predict_stream,
|
| 819 |
-
bot_client=bot_client
|
| 820 |
-
faiss_db=faiss_db
|
| 821 |
)
|
| 822 |
regenerate_with_clients = partial(
|
| 823 |
GradioEvents.regenerate,
|
| 824 |
-
bot_client=bot_client
|
| 825 |
-
faiss_db=faiss_db
|
| 826 |
)
|
| 827 |
file_upload_with_clients = partial(
|
| 828 |
GradioEvents.file_upload,
|
| 829 |
-
faiss_db=faiss_db
|
| 830 |
)
|
| 831 |
|
| 832 |
chunk_size = gr.State(args.chunk_size)
|
| 833 |
file_btn.change(
|
| 834 |
fn=file_upload_with_clients,
|
| 835 |
-
inputs=[file_btn, chunk_size],
|
| 836 |
outputs=[progress_bar],
|
| 837 |
)
|
| 838 |
query.submit(
|
| 839 |
predict_with_clients,
|
| 840 |
-
inputs=[query, chatbot, task_history, model_name],
|
| 841 |
outputs=[chatbot, relevant_passage],
|
| 842 |
show_progress=True
|
| 843 |
)
|
| 844 |
query.submit(GradioEvents.reset_user_input, [], [query])
|
| 845 |
submit_btn.click(
|
| 846 |
predict_with_clients,
|
| 847 |
-
inputs=[query, chatbot, task_history, model_name],
|
| 848 |
outputs=[chatbot, relevant_passage],
|
| 849 |
show_progress=True,
|
| 850 |
)
|
|
@@ -855,7 +938,7 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db: Faiss
|
|
| 855 |
)
|
| 856 |
regen_btn.click(
|
| 857 |
regenerate_with_clients,
|
| 858 |
-
inputs=[chatbot, task_history, model_name],
|
| 859 |
outputs=[chatbot, relevant_passage],
|
| 860 |
show_progress=True
|
| 861 |
)
|
|
@@ -873,7 +956,7 @@ def main():
|
|
| 873 |
faiss_db = FaissTextDatabase(args, bot_client)
|
| 874 |
|
| 875 |
# Run file upload function to save default knowledge base.
|
| 876 |
-
GradioEvents.save_file_to_db(FILE_URL_DEFAULT, args.chunk_size, faiss_db)
|
| 877 |
|
| 878 |
launch_demo(args, bot_client, faiss_db)
|
| 879 |
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
+
"""
|
| 16 |
+
This script provides a Gradio interface for interacting with a chatbot based on Retrieval-Augmented Generation.
|
| 17 |
+
"""
|
| 18 |
|
| 19 |
import argparse
|
| 20 |
import base64
|
| 21 |
from collections import namedtuple
|
| 22 |
+
from datetime import datetime
|
| 23 |
from functools import partial
|
| 24 |
import hashlib
|
| 25 |
import json
|
|
|
|
| 28 |
import os
|
| 29 |
from argparse import ArgumentParser
|
| 30 |
import textwrap
|
| 31 |
+
import copy
|
| 32 |
|
| 33 |
import gradio as gr
|
| 34 |
import numpy as np
|
| 35 |
|
| 36 |
from bot_requests import BotClient
|
|
|
|
| 37 |
|
| 38 |
os.environ["NO_PROXY"] = "localhost,127.0.0.1" # Disable proxy
|
| 39 |
|
|
|
|
| 47 |
)
|
| 48 |
|
| 49 |
QUERY_REWRITE_PROMPT = textwrap.dedent("""\
|
| 50 |
+
【当前时间】
|
| 51 |
+
{TIMESTAMP}
|
| 52 |
+
|
| 53 |
+
【对话内容】
|
| 54 |
+
{CONVERSATION}
|
| 55 |
+
|
| 56 |
+
你的任务是根据上面user与assistant的对话内容,理解user意图,改写user的最后一轮对话,以便更高效地从知识库查找相关知识。具体的改写要求如下:
|
| 57 |
+
1. 如果user的问题包括几个小问题,请将它们分成多个单独的问题。
|
| 58 |
+
2. 如果user的问题涉及到之前对话的信息,请将这些信息融入问题中,形成一个不需要上下文就可以理解的完整问题。
|
| 59 |
+
3. 如果user的问题是在比较或关联多个事物时,先将其拆分为单个事物的问题,例如‘A与B比起来怎么样’,拆分为:‘A怎么样’以及‘B怎么样’。
|
| 60 |
+
4. 如果user的问题中描述事物的限定词有多个,请将多个限定词拆分成单个限定词。
|
| 61 |
+
5. 如果user的问题具有**时效性(需要包含当前时间信息,才能得到正确的回复)**的时候,需要将当前时间信息添加到改写的query中;否则不加入当前时间信息。
|
| 62 |
+
6. 只在**确有必要**的情况下改写,不需要改写时query输出[]。输出不超过 5 个改写问题,不要为了凑满数量而输出冗余问题。
|
| 63 |
+
|
| 64 |
+
【输出格式】只输出 JSON ,不要给出多余内容
|
| 65 |
+
```json
|
|
|
|
| 66 |
{{
|
| 67 |
+
"query": ["改写问题1", "改写问题2"...]
|
| 68 |
+
}}```
|
| 69 |
+
"""
|
|
|
|
| 70 |
)
|
|
|
|
| 71 |
ANSWER_PROMPT = textwrap.dedent(
|
| 72 |
"""\
|
| 73 |
+
你是阅读理解问答专家。
|
| 74 |
+
|
| 75 |
+
【文档知识】
|
| 76 |
+
{DOC_CONTENT}
|
| 77 |
+
|
| 78 |
+
你的任务是根据对话内容,理解用户需求,参考文档知识回答用户问题,知识参考详细原则如下:
|
| 79 |
+
- 对于同一信息点,如文档知识与模型通用知识均可支撑,应优先以文档知识为主,并对信息进行验证和综合。
|
| 80 |
+
- 如果文档知识不足或信息冲突,必须指出“根据资料无法确定”或“不同资料存在矛盾”,不得引入文档知识与通识之外的主观推测。
|
| 81 |
+
|
| 82 |
+
同时,回答问题需要综合考虑规则要求中的各项内容,详细要求如下:
|
| 83 |
+
【规则要求】
|
| 84 |
+
* 回答问题时,应优先参考与问题紧密相关的文档知识,不要在答案中引入任何与问题无关的文档内容。
|
| 85 |
+
* 回答中不可以让用户知道你查询了相关文档。
|
| 86 |
+
* 回复答案不要出现'根据文档知识','根据当前时间'等表述。
|
| 87 |
+
* 论述突出重点内容,以分点条理清晰的结构化格式输出。
|
| 88 |
|
| 89 |
+
【当前时间】
|
| 90 |
+
{TIMESTAMP}
|
| 91 |
|
| 92 |
+
【对话内容】
|
| 93 |
+
{CONVERSATION}
|
| 94 |
|
| 95 |
+
直接输出回复内容即可。
|
| 96 |
+
"""
|
| 97 |
)
|
| 98 |
QUERY_DEFAULT = "1675 年时,英格兰有多少家咖啡馆?"
|
| 99 |
|
| 100 |
|
| 101 |
def get_args() -> argparse.Namespace:
|
| 102 |
"""
|
| 103 |
+
Parse and return command line arguments for the ERNIE models chat demo.
|
| 104 |
+
Configures server settings, model endpoint, and document processing parameters.
|
| 105 |
|
| 106 |
Returns:
|
| 107 |
+
argparse.Namespace: Parsed command line arguments containing all the above settings.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
"""
|
| 109 |
parser = ArgumentParser(description="ERNIE models web chat demo.")
|
| 110 |
|
| 111 |
parser.add_argument(
|
| 112 |
+
"--server-port", type=int, default=8686, help="Demo server port."
|
| 113 |
)
|
| 114 |
parser.add_argument(
|
| 115 |
"--server-name", type=str, default="0.0.0.0", help="Demo server name."
|
| 116 |
)
|
| 117 |
parser.add_argument(
|
| 118 |
+
"--max_char", type=int, default=20000, help="Maximum character limit for messages."
|
| 119 |
)
|
| 120 |
parser.add_argument(
|
| 121 |
"--max_retry_num", type=int, default=3, help="Maximum retry number for request."
|
| 122 |
)
|
| 123 |
parser.add_argument(
|
| 124 |
+
"--model_map",
|
| 125 |
type=str,
|
| 126 |
+
default="{\"ernie-4.5-turbo-vl-32k\": \"https://qianfan.baidubce.com/v2\"}",
|
| 127 |
+
help="""JSON string defining model name to endpoint mappings.
|
| 128 |
+
Required Format:
|
| 129 |
+
{"ERNIE-4.5": "http://localhost:port/v1"}
|
| 130 |
+
|
| 131 |
+
Note:
|
| 132 |
+
- Endpoints must be valid HTTP URL
|
| 133 |
+
- Specify ONE model endpoint in JSON format.
|
| 134 |
+
- Prefix determines model capabilities:
|
| 135 |
+
* ERNIE-4.5: Text-only model
|
| 136 |
+
"""
|
| 137 |
)
|
| 138 |
parser.add_argument(
|
| 139 |
+
"--embedding_service_url",
|
| 140 |
type=str,
|
| 141 |
default="https://qianfan.baidubce.com/v2",
|
| 142 |
+
help="Embedding service url."
|
| 143 |
)
|
| 144 |
parser.add_argument(
|
| 145 |
"--qianfan_api_key",
|
| 146 |
type=str,
|
| 147 |
default=os.environ.get("API_KEY"),
|
| 148 |
+
help="Qianfan API key.",
|
| 149 |
)
|
| 150 |
parser.add_argument(
|
| 151 |
"--embedding_model",
|
|
|
|
| 153 |
default="embedding-v1",
|
| 154 |
help="Embedding model name."
|
| 155 |
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--embedding_dim",
|
| 158 |
+
type=int,
|
| 159 |
+
default=384,
|
| 160 |
+
help="Dimension of the embedding vector."
|
| 161 |
+
)
|
| 162 |
parser.add_argument(
|
| 163 |
"--chunk_size",
|
| 164 |
type=int,
|
| 165 |
default=512,
|
| 166 |
help="Chunk size for splitting long documents."
|
| 167 |
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--top_k",
|
| 170 |
+
type=int,
|
| 171 |
+
default=3,
|
| 172 |
+
help="Top k results to retrieve."
|
| 173 |
+
)
|
| 174 |
parser.add_argument(
|
| 175 |
"--faiss_index_path",
|
| 176 |
type=str,
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
args = parser.parse_args()
|
| 188 |
+
try:
|
| 189 |
+
args.model_map = json.loads(args.model_map)
|
| 190 |
+
|
| 191 |
+
# Validation: Check at least one model exists
|
| 192 |
+
if len(args.model_map) < 1:
|
| 193 |
+
raise ValueError("model_map must contain at least one model configuration")
|
| 194 |
+
except json.JSONDecodeError as e:
|
| 195 |
+
raise ValueError("Invalid JSON format for --model-map") from e
|
| 196 |
+
|
| 197 |
return args
|
| 198 |
|
| 199 |
|
| 200 |
class FaissTextDatabase:
|
| 201 |
"""
|
| 202 |
+
A vector database for text retrieval using FAISS.
|
| 203 |
Provides efficient similarity search and document management capabilities.
|
| 204 |
"""
|
| 205 |
+
def __init__(self, args, bot_client: BotClient):
|
| 206 |
"""
|
| 207 |
Initialize the FaissTextDatabase.
|
| 208 |
|
|
|
|
| 214 |
self.logger = logging.getLogger(__name__)
|
| 215 |
|
| 216 |
self.bot_client = bot_client
|
| 217 |
+
self.embedding_dim = getattr(args, "embedding_dim", 384)
|
| 218 |
+
self.top_k = getattr(args, "top_k", 3)
|
| 219 |
+
self.context_size = getattr(args, "context_size", 2)
|
| 220 |
self.faiss_index_path = getattr(args, "faiss_index_path", "data/faiss_index")
|
| 221 |
self.text_db_path = getattr(args, "text_db_path", "data/text_db.jsonl")
|
|
|
|
| 222 |
|
| 223 |
# If faiss_index_path exists, load it and text_db_path
|
| 224 |
if os.path.exists(self.faiss_index_path) and os.path.exists(self.text_db_path):
|
|
|
|
| 258 |
file_md5 = self.calculate_md5(file_path)
|
| 259 |
return file_md5 in self.text_db["file_md5s"]
|
| 260 |
|
| 261 |
+
def add_embeddings(self, file_path: str, segments: list[str], progress_bar: gr.Progress=None, \
|
| 262 |
+
save_file: bool=False) -> bool:
|
| 263 |
"""
|
| 264 |
Stores document embeddings in FAISS database after checking for duplicates.
|
| 265 |
Generates embeddings for each text segment, updates the FAISS index and metadata database,
|
|
|
|
| 284 |
# Generate embeddings
|
| 285 |
vectors = []
|
| 286 |
file_name = os.path.basename(file_path)
|
| 287 |
+
file_txt = "".join(file_name.split(".")[:-1])[:30]
|
| 288 |
for i, segment in enumerate(segments):
|
| 289 |
+
vectors.append(self.bot_client.embed_fn(file_txt + "\n" + segment))
|
| 290 |
if progress_bar is not None:
|
| 291 |
progress_bar((i + 1) / len(segments), desc=file_name + " Processing...")
|
| 292 |
vectors = np.array(vectors)
|
|
|
|
| 296 |
for i, text in enumerate(segments):
|
| 297 |
self.text_db["chunks"].append({
|
| 298 |
"file_md5": file_md5,
|
| 299 |
+
"file_name": file_name,
|
| 300 |
+
"file_txt": file_txt,
|
| 301 |
"text": text,
|
| 302 |
"vector_id": start_id + i
|
| 303 |
})
|
| 304 |
|
| 305 |
self.text_db["file_md5s"].append(file_md5)
|
| 306 |
+
if save_file:
|
| 307 |
+
self.save()
|
| 308 |
return True
|
| 309 |
|
| 310 |
+
def search_with_context(self, query_list: list) -> str:
|
| 311 |
"""
|
| 312 |
+
Finds the most relevant text chunks for multiple queries and includes surrounding context.
|
| 313 |
+
Uses FAISS to find the closest matching embeddings, then retrieves adjacent chunks
|
| 314 |
from the same source document to provide better context understanding.
|
| 315 |
+
|
| 316 |
Args:
|
| 317 |
+
query_list: list of input query strings
|
|
|
|
| 318 |
|
| 319 |
Returns:
|
| 320 |
+
str: the concatenated output string
|
| 321 |
+
"""
|
| 322 |
+
# Step 1: Retrieve top_k results for each query and collect all indices
|
| 323 |
+
all_indices = []
|
| 324 |
+
for query in query_list:
|
| 325 |
+
query_vector = np.array([self.bot_client.embed_fn(query)]).astype('float32')
|
| 326 |
+
_, indices = self.index.search(query_vector, self.top_k)
|
| 327 |
+
all_indices.extend(indices[0].tolist())
|
| 328 |
|
| 329 |
+
# Step 2: Remove duplicate indices
|
| 330 |
+
unique_indices = sorted(list(set(all_indices)))
|
| 331 |
+
self.logger.info(f"Retrieved indices: {all_indices}")
|
| 332 |
+
self.logger.info(f"Unique indices after deduplication: {unique_indices}")
|
|
|
|
| 333 |
|
| 334 |
+
# Step 3: Expand each index with context (within same file boundaries)
|
| 335 |
+
expanded_indices = set()
|
| 336 |
+
file_boundaries = {} # {file_md5: (start_idx, end_idx)}
|
| 337 |
+
for target_idx in unique_indices:
|
| 338 |
+
target_chunk = self.text_db["chunks"][target_idx]
|
| 339 |
+
target_file_md5 = target_chunk["file_md5"]
|
| 340 |
+
|
| 341 |
+
if target_file_md5 not in file_boundaries:
|
| 342 |
+
file_start = target_idx
|
| 343 |
+
while file_start > 0 and self.text_db["chunks"][file_start - 1]["file_md5"] == target_file_md5:
|
| 344 |
+
file_start -= 1
|
| 345 |
+
file_end = target_idx
|
| 346 |
+
while (file_end < len(self.text_db["chunks"]) - 1 and
|
| 347 |
+
self.text_db["chunks"][file_end + 1]["file_md5"] == target_file_md5):
|
| 348 |
+
file_end += 1
|
| 349 |
+
else:
|
| 350 |
+
file_start, file_end = file_boundaries[target_file_md5]
|
| 351 |
+
|
| 352 |
+
# Calculate context range within file boundaries
|
| 353 |
+
start = max(file_start, target_idx - self.context_size)
|
| 354 |
+
end = min(file_end, target_idx + self.context_size)
|
| 355 |
+
|
| 356 |
+
for pos in range(start, end + 1):
|
| 357 |
+
expanded_indices.add(pos)
|
| 358 |
+
|
| 359 |
+
# Step 4: Sort and merge continuous chunks
|
| 360 |
+
sorted_indices = sorted(list(expanded_indices))
|
| 361 |
+
groups = []
|
| 362 |
+
current_group = [sorted_indices[0]]
|
| 363 |
+
for i in range(1, len(sorted_indices)):
|
| 364 |
+
if (sorted_indices[i] == sorted_indices[i - 1] + 1 and
|
| 365 |
+
self.text_db["chunks"][sorted_indices[i]]["file_md5"] ==
|
| 366 |
+
self.text_db["chunks"][sorted_indices[i - 1]]["file_md5"]):
|
| 367 |
+
current_group.append(sorted_indices[i])
|
| 368 |
+
else:
|
| 369 |
+
groups.append(current_group)
|
| 370 |
+
current_group = [sorted_indices[i]]
|
| 371 |
+
groups.append(current_group)
|
| 372 |
+
|
| 373 |
+
# Step 5: Create merged text for each group
|
| 374 |
result = ""
|
| 375 |
+
for idx, group in enumerate(groups):
|
| 376 |
+
result += "\n段落{idx}:\n{title}\n".format(idx=idx + 1, title=self.text_db["chunks"][group[0]]["file_txt"])
|
| 377 |
+
for idx in group:
|
| 378 |
+
result += self.text_db["chunks"][idx]["text"] + "\n"
|
| 379 |
+
self.logger.info(f"Merged chunk range: {group[0]}-{group[-1]}")
|
| 380 |
|
| 381 |
return result
|
| 382 |
|
|
|
|
| 393 |
Manages event handling and UI interactions for Gradio applications.
|
| 394 |
Provides methods to process user inputs, trigger callbacks, and update interface components.
|
| 395 |
"""
|
| 396 |
+
@staticmethod
|
| 397 |
+
def get_history_conversation(task_history: list) -> tuple:
|
| 398 |
+
"""
|
| 399 |
+
Converts task history into conversation format for model processing.
|
| 400 |
+
Transforms query-response pairs into structured message history and plain text.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
task_history (list): List of tuples containing queries and responses.
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
tuple: Tuple containing two elements:
|
| 407 |
+
- conversation (list): List of dictionaries representing the conversation history.
|
| 408 |
+
- conversation_str (str): String representation of the conversation history.
|
| 409 |
+
"""
|
| 410 |
+
conversation = []
|
| 411 |
+
conversation_str = ""
|
| 412 |
+
for query_h, response_h in task_history:
|
| 413 |
+
conversation.append({"role": "user", "content": query_h})
|
| 414 |
+
conversation.append({"role": "assistant", "content": response_h})
|
| 415 |
+
conversation_str += "user:\n{query}\n assistant:\n{response}\n ".format(query=query_h, response=response_h)
|
| 416 |
+
return conversation, conversation_str
|
| 417 |
+
|
| 418 |
@staticmethod
|
| 419 |
def chat_stream(
|
| 420 |
query: str,
|
| 421 |
task_history: list,
|
| 422 |
model: str,
|
|
|
|
| 423 |
faiss_db: FaissTextDatabase,
|
| 424 |
+
bot_client: BotClient,
|
| 425 |
) -> dict:
|
| 426 |
"""
|
| 427 |
Streams chatbot responses by processing queries with context from history and FAISS database.
|
|
|
|
| 438 |
Yields:
|
| 439 |
dict: A dictionary containing the event type and its corresponding content.
|
| 440 |
"""
|
| 441 |
+
conversation, conversation_str = GradioEvents.get_history_conversation(task_history)
|
| 442 |
+
conversation_str += "user:\n{query}\n".format(query=query)
|
| 443 |
+
|
| 444 |
+
search_info_message = QUERY_REWRITE_PROMPT.format(
|
| 445 |
+
TIMESTAMP=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 446 |
+
CONVERSATION=conversation_str
|
| 447 |
+
)
|
| 448 |
+
search_conversation = [{"role": "user", "content": search_info_message}]
|
| 449 |
+
search_info_result = GradioEvents.get_sub_query(search_conversation, model, bot_client)
|
| 450 |
+
if search_info_result is None:
|
| 451 |
+
search_info_result = {"query": [query]}
|
| 452 |
+
|
| 453 |
+
if search_info_result.get("query", []):
|
| 454 |
+
relevant_passages = faiss_db.search_with_context(search_info_result["query"])
|
| 455 |
+
yield {"type": "relevant_passage", "content": relevant_passages}
|
| 456 |
+
|
| 457 |
+
query = ANSWER_PROMPT.format(
|
| 458 |
+
DOC_CONTENT=relevant_passages,
|
| 459 |
+
TIMESTAMP=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 460 |
+
CONVERSATION=conversation_str
|
| 461 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
|
| 463 |
+
conversation.append({"role": "user", "content": query})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
try:
|
| 465 |
req_data = {"messages": conversation}
|
| 466 |
for chunk in bot_client.process_stream(model, req_data):
|
|
|
|
| 469 |
|
| 470 |
message = chunk.get("choices", [{}])[0].get("delta", {})
|
| 471 |
content = message.get("content", "")
|
|
|
|
| 472 |
|
|
|
|
|
|
|
| 473 |
if content:
|
| 474 |
yield {"type": "answer", "content": content}
|
| 475 |
|
|
|
|
| 482 |
chatbot: list,
|
| 483 |
task_history: list,
|
| 484 |
model: str,
|
| 485 |
+
faiss_db: FaissTextDatabase,
|
| 486 |
bot_client: BotClient,
|
|
|
|
| 487 |
) -> tuple:
|
| 488 |
"""
|
| 489 |
Generates streaming responses by combining model predictions with knowledge retrieval.
|
|
|
|
| 513 |
query,
|
| 514 |
task_history,
|
| 515 |
model,
|
|
|
|
| 516 |
faiss_db,
|
| 517 |
+
bot_client,
|
| 518 |
)
|
| 519 |
+
|
| 520 |
response = ""
|
|
|
|
| 521 |
current_relevant_passage = None
|
| 522 |
for new_text in new_texts:
|
| 523 |
if not isinstance(new_text, dict):
|
|
|
|
| 531 |
current_relevant_passage = new_text["content"]
|
| 532 |
yield chatbot, current_relevant_passage
|
| 533 |
continue
|
|
|
|
|
|
|
|
|
|
| 534 |
elif new_text.get("type") == "answer":
|
| 535 |
response += new_text["content"]
|
| 536 |
|
| 537 |
+
# Remove previous message if exists
|
| 538 |
if chatbot[-1].get("role") == "assistant":
|
| 539 |
chatbot.pop(-1)
|
| 540 |
|
|
|
|
|
|
|
|
|
|
| 541 |
if response:
|
| 542 |
+
chatbot.append({"role": "assistant", "content": response})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
yield chatbot, current_relevant_passage
|
| 544 |
|
| 545 |
logging.info("History: {}".format(task_history))
|
|
|
|
| 551 |
chatbot: list,
|
| 552 |
task_history: list,
|
| 553 |
model: str,
|
| 554 |
+
faiss_db: FaissTextDatabase,
|
| 555 |
bot_client: BotClient,
|
|
|
|
| 556 |
) -> tuple:
|
| 557 |
"""
|
| 558 |
Regenerate the chatbot's response based on the latest user query
|
|
|
|
| 581 |
chatbot,
|
| 582 |
task_history,
|
| 583 |
model,
|
| 584 |
+
faiss_db,
|
| 585 |
bot_client,
|
|
|
|
| 586 |
):
|
| 587 |
yield chunk, relevant_passage
|
| 588 |
|
|
|
|
| 648 |
return url
|
| 649 |
|
| 650 |
@staticmethod
|
| 651 |
+
def get_sub_query(conversation: list, model_name: str, bot_client: BotClient) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
"""
|
| 653 |
Enhances user queries by generating alternative phrasings using language models.
|
| 654 |
Creates semantically similar variations of the original query to improve retrieval accuracy.
|
| 655 |
Returns structured dictionary containing both original and rephrased queries.
|
| 656 |
|
| 657 |
Args:
|
| 658 |
+
conversation (list): The conversation history.
|
| 659 |
model_name (str): The name of the model to use for rephrasing.
|
| 660 |
bot_client (BotClient): The bot client instance.
|
| 661 |
|
| 662 |
Returns:
|
| 663 |
dict: The rephrased query.
|
| 664 |
"""
|
|
|
|
|
|
|
| 665 |
req_data = {"messages": conversation}
|
| 666 |
try:
|
| 667 |
response = bot_client.process(model_name, req_data)
|
|
|
|
| 676 |
search_info_res["sub_query_list"] = unique_list
|
| 677 |
return search_info_res
|
| 678 |
except Exception:
|
| 679 |
+
logging.error("Error: Model output is not a valid JSON")
|
| 680 |
+
return None
|
| 681 |
|
| 682 |
@staticmethod
|
| 683 |
def split_oversized_line(line: str, chunk_size: int) -> tuple:
|
|
|
|
| 692 |
Returns:
|
| 693 |
tuple: Two strings, the first part of the original line and the rest of the line.
|
| 694 |
"""
|
| 695 |
+
PUNCTUATIONS = {".", "。", "!", "!", "?", "?", ",", ",", ";", ";", ":", ":"}
|
| 696 |
|
| 697 |
if len(line) <= chunk_size:
|
| 698 |
return line, ""
|
|
|
|
| 713 |
return line[:split_pos], line[split_pos:]
|
| 714 |
|
| 715 |
@staticmethod
|
| 716 |
+
def split_text_into_chunks(file_url: str, chunk_size: int) -> list:
|
| 717 |
"""
|
| 718 |
+
Split file text into chunks of a specified size while respecting natural language boundaries
|
| 719 |
and avoiding mid-word splits whenever possible.
|
| 720 |
|
| 721 |
Args:
|
| 722 |
+
file_url (str): The file URL.
|
| 723 |
chunk_size (int): The maximum length of each chunk.
|
| 724 |
|
| 725 |
Returns:
|
| 726 |
list: A list of strings, where each element represents a chunk of the original text.
|
| 727 |
"""
|
| 728 |
+
with open(file_url, "r", encoding="utf-8") as f:
|
| 729 |
+
text = f.read()
|
| 730 |
+
|
| 731 |
+
if not text:
|
| 732 |
+
logging.error("Error: File is empty")
|
| 733 |
+
return []
|
| 734 |
+
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
| 735 |
chunks = []
|
| 736 |
current_chunk = []
|
| 737 |
current_length = 0
|
| 738 |
|
| 739 |
for line in lines:
|
|
|
|
| 740 |
# If adding this line would exceed chunk size (and we have content)
|
| 741 |
if current_length + len(line) > chunk_size and current_chunk:
|
| 742 |
+
chunks.append("\n".join(current_chunk))
|
| 743 |
current_chunk = []
|
| 744 |
current_length = 0
|
| 745 |
|
|
|
|
| 754 |
current_length += len(line) + 1
|
| 755 |
|
| 756 |
if current_chunk:
|
| 757 |
+
chunks.append("\n".join(current_chunk))
|
| 758 |
return chunks
|
| 759 |
|
| 760 |
@staticmethod
|
|
|
|
| 788 |
yield gr.update(visible=False)
|
| 789 |
|
| 790 |
@staticmethod
|
| 791 |
+
def save_file_to_db(file_url: str, chunk_size: int, faiss_db: FaissTextDatabase, \
|
| 792 |
+
progress_bar: gr.Progress=None, save_file: bool=False):
|
| 793 |
"""
|
| 794 |
Processes and indexes document content into FAISS database with semantic-aware chunking.
|
| 795 |
Handles file validation, text segmentation, embedding generation and storage operations.
|
|
|
|
| 803 |
Returns:
|
| 804 |
bool: True if the file was saved successfully, otherwise False.
|
| 805 |
"""
|
| 806 |
+
if not os.path.exists(file_url):
|
| 807 |
+
logging.error("File not found: {}".format(file_url))
|
| 808 |
+
return False
|
| 809 |
+
|
| 810 |
file_name = os.path.basename(file_url)
|
| 811 |
if not faiss_db.is_file_processed(file_url):
|
| 812 |
logging.info("{} not processed yet, processing now...".format(file_url))
|
| 813 |
try:
|
| 814 |
+
segments = GradioEvents.split_text_into_chunks(file_url, chunk_size)
|
| 815 |
+
faiss_db.add_embeddings(file_url, segments, progress_bar, save_file)
|
|
|
|
|
|
|
| 816 |
|
| 817 |
logging.info("{} processed successfully.".format(file_url))
|
| 818 |
return True
|
|
|
|
| 825 |
return False
|
| 826 |
|
| 827 |
|
| 828 |
+
def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db_template: FaissTextDatabase):
|
| 829 |
"""
|
| 830 |
Launch demo program
|
| 831 |
|
|
|
|
| 855 |
}
|
| 856 |
"""
|
| 857 |
with gr.Blocks(css=css) as demo:
|
| 858 |
+
model_name = gr.State(list(args.model_map.keys())[0])
|
| 859 |
+
faiss_db = gr.State(copy.deepcopy(faiss_db_template))
|
| 860 |
|
| 861 |
logo_url = GradioEvents.get_image_url("assets/logo.png")
|
| 862 |
gr.Markdown("""\
|
|
|
|
| 902 |
|
| 903 |
predict_with_clients = partial(
|
| 904 |
GradioEvents.predict_stream,
|
| 905 |
+
bot_client=bot_client
|
|
|
|
| 906 |
)
|
| 907 |
regenerate_with_clients = partial(
|
| 908 |
GradioEvents.regenerate,
|
| 909 |
+
bot_client=bot_client
|
|
|
|
| 910 |
)
|
| 911 |
file_upload_with_clients = partial(
|
| 912 |
GradioEvents.file_upload,
|
|
|
|
| 913 |
)
|
| 914 |
|
| 915 |
chunk_size = gr.State(args.chunk_size)
|
| 916 |
file_btn.change(
|
| 917 |
fn=file_upload_with_clients,
|
| 918 |
+
inputs=[file_btn, chunk_size, faiss_db],
|
| 919 |
outputs=[progress_bar],
|
| 920 |
)
|
| 921 |
query.submit(
|
| 922 |
predict_with_clients,
|
| 923 |
+
inputs=[query, chatbot, task_history, model_name, faiss_db],
|
| 924 |
outputs=[chatbot, relevant_passage],
|
| 925 |
show_progress=True
|
| 926 |
)
|
| 927 |
query.submit(GradioEvents.reset_user_input, [], [query])
|
| 928 |
submit_btn.click(
|
| 929 |
predict_with_clients,
|
| 930 |
+
inputs=[query, chatbot, task_history, model_name, faiss_db],
|
| 931 |
outputs=[chatbot, relevant_passage],
|
| 932 |
show_progress=True,
|
| 933 |
)
|
|
|
|
| 938 |
)
|
| 939 |
regen_btn.click(
|
| 940 |
regenerate_with_clients,
|
| 941 |
+
inputs=[chatbot, task_history, model_name, faiss_db],
|
| 942 |
outputs=[chatbot, relevant_passage],
|
| 943 |
show_progress=True
|
| 944 |
)
|
|
|
|
| 956 |
faiss_db = FaissTextDatabase(args, bot_client)
|
| 957 |
|
| 958 |
# Run file upload function to save default knowledge base.
|
| 959 |
+
GradioEvents.save_file_to_db(FILE_URL_DEFAULT, args.chunk_size, faiss_db, save_file=True)
|
| 960 |
|
| 961 |
launch_demo(args, bot_client, faiss_db)
|
| 962 |
|
bot_requests.py
CHANGED
|
@@ -22,7 +22,7 @@ import json
|
|
| 22 |
import jieba
|
| 23 |
from openai import OpenAI
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
class BotClient(object):
|
| 28 |
"""Client for interacting with various AI models."""
|
|
@@ -41,15 +41,16 @@ class BotClient(object):
|
|
| 41 |
self.max_retry_num = getattr(args, 'max_retry_num', 3)
|
| 42 |
self.max_char = getattr(args, 'max_char', 8000)
|
| 43 |
|
| 44 |
-
self.
|
| 45 |
-
self.x1_model_url = getattr(args, 'x1_model_url', 'x1_model_url')
|
| 46 |
self.api_key = os.environ.get("API_KEY")
|
| 47 |
|
| 48 |
-
self.
|
| 49 |
-
self.qianfan_api_key = getattr(args, 'qianfan_api_key', 'qianfan_api_key')
|
| 50 |
self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
|
| 51 |
|
| 52 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def call_back(self, host_url: str, req_data: dict) -> dict:
|
| 55 |
"""
|
|
@@ -130,14 +131,9 @@ class BotClient(object):
|
|
| 130 |
Returns:
|
| 131 |
dict: Dictionary containing the model's processing results.
|
| 132 |
"""
|
| 133 |
-
|
| 134 |
-
"eb-45t": self.eb45t_model_url,
|
| 135 |
-
"eb-x1": self.x1_model_url
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
model_url = model_map[model_name]
|
| 139 |
|
| 140 |
-
req_data["model"] =
|
| 141 |
req_data["max_tokens"] = max_tokens
|
| 142 |
req_data["temperature"] = temperature
|
| 143 |
req_data["top_p"] = top_p
|
|
@@ -157,7 +153,6 @@ class BotClient(object):
|
|
| 157 |
res = {}
|
| 158 |
if len(res) != 0 and "error" not in res:
|
| 159 |
break
|
| 160 |
-
self.logger.info(json.dumps(res, ensure_ascii=False))
|
| 161 |
|
| 162 |
return res
|
| 163 |
|
|
@@ -183,13 +178,8 @@ class BotClient(object):
|
|
| 183 |
Yields:
|
| 184 |
dict: Dictionary containing the model's processing results.
|
| 185 |
"""
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
"eb-x1": self.x1_model_url
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
model_url = model_map[model_name]
|
| 192 |
-
req_data["model"] = "ernie-4.5-turbo-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
|
| 193 |
req_data["max_tokens"] = max_tokens
|
| 194 |
req_data["temperature"] = temperature
|
| 195 |
req_data["top_p"] = top_p
|
|
@@ -282,7 +272,7 @@ class BotClient(object):
|
|
| 282 |
to_remove = total_units - self.max_char
|
| 283 |
|
| 284 |
# 1. Truncate historical messages
|
| 285 |
-
for i in range(
|
| 286 |
if to_remove <= 0:
|
| 287 |
break
|
| 288 |
|
|
@@ -362,27 +352,39 @@ class BotClient(object):
|
|
| 362 |
Returns:
|
| 363 |
list: A list of floats representing the embedding.
|
| 364 |
"""
|
| 365 |
-
client = OpenAI(base_url=self.
|
| 366 |
response = client.embeddings.create(input=[text], model=self.embedding_model)
|
| 367 |
return response.data[0].embedding
|
| 368 |
|
| 369 |
-
|
| 370 |
"""
|
| 371 |
-
|
| 372 |
-
|
| 373 |
Args:
|
| 374 |
-
query_list (list): List of queries to
|
| 375 |
|
| 376 |
Returns:
|
| 377 |
-
list: List of
|
| 378 |
"""
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
import jieba
|
| 23 |
from openai import OpenAI
|
| 24 |
|
| 25 |
+
import requests
|
| 26 |
|
| 27 |
class BotClient(object):
|
| 28 |
"""Client for interacting with various AI models."""
|
|
|
|
| 41 |
self.max_retry_num = getattr(args, 'max_retry_num', 3)
|
| 42 |
self.max_char = getattr(args, 'max_char', 8000)
|
| 43 |
|
| 44 |
+
self.model_map = getattr(args, 'model_map', {})
|
|
|
|
| 45 |
self.api_key = os.environ.get("API_KEY")
|
| 46 |
|
| 47 |
+
self.embedding_service_url = getattr(args, 'embedding_service_url', 'embedding_service_url')
|
|
|
|
| 48 |
self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
|
| 49 |
|
| 50 |
+
self.web_search_service_url = getattr(args, 'web_search_service_url', 'web_search_service_url')
|
| 51 |
+
self.max_search_results_num = getattr(args, 'max_search_results_num', 15)
|
| 52 |
+
|
| 53 |
+
self.qianfan_api_key = os.environ.get("API_KEY")
|
| 54 |
|
| 55 |
def call_back(self, host_url: str, req_data: dict) -> dict:
|
| 56 |
"""
|
|
|
|
| 131 |
Returns:
|
| 132 |
dict: Dictionary containing the model's processing results.
|
| 133 |
"""
|
| 134 |
+
model_url = self.model_map[model_name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
req_data["model"] = model_name
|
| 137 |
req_data["max_tokens"] = max_tokens
|
| 138 |
req_data["temperature"] = temperature
|
| 139 |
req_data["top_p"] = top_p
|
|
|
|
| 153 |
res = {}
|
| 154 |
if len(res) != 0 and "error" not in res:
|
| 155 |
break
|
|
|
|
| 156 |
|
| 157 |
return res
|
| 158 |
|
|
|
|
| 178 |
Yields:
|
| 179 |
dict: Dictionary containing the model's processing results.
|
| 180 |
"""
|
| 181 |
+
model_url = self.model_map[model_name]
|
| 182 |
+
req_data["model"] = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
req_data["max_tokens"] = max_tokens
|
| 184 |
req_data["temperature"] = temperature
|
| 185 |
req_data["top_p"] = top_p
|
|
|
|
| 272 |
to_remove = total_units - self.max_char
|
| 273 |
|
| 274 |
# 1. Truncate historical messages
|
| 275 |
+
for i in range(len(processed) - 1, 1):
|
| 276 |
if to_remove <= 0:
|
| 277 |
break
|
| 278 |
|
|
|
|
| 352 |
Returns:
|
| 353 |
list: A list of floats representing the embedding.
|
| 354 |
"""
|
| 355 |
+
client = OpenAI(base_url=self.embedding_service_url, api_key=self.qianfan_api_key)
|
| 356 |
response = client.embeddings.create(input=[text], model=self.embedding_model)
|
| 357 |
return response.data[0].embedding
|
| 358 |
|
| 359 |
+
def get_web_search_res(self, query_list: list) -> list:
|
| 360 |
"""
|
| 361 |
+
Send a request to the AI Search service using the provided API key and service URL.
|
| 362 |
+
|
| 363 |
Args:
|
| 364 |
+
query_list (list): List of queries to send to the AI Search service.
|
| 365 |
|
| 366 |
Returns:
|
| 367 |
+
list: List of responses from the AI Search service.
|
| 368 |
"""
|
| 369 |
+
headers = {
|
| 370 |
+
"Authorization": "Bearer " + self.qianfan_api_key,
|
| 371 |
+
"Content-Type": "application/json"
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
results = []
|
| 375 |
+
top_k = self.max_search_results_num // len(query_list)
|
| 376 |
+
for query in query_list:
|
| 377 |
+
payload = {
|
| 378 |
+
"messages": [{"role": "user", "content": query}],
|
| 379 |
+
"resource_type_filter": [{"type": "web", "top_k": top_k}]
|
| 380 |
+
}
|
| 381 |
+
response = requests.post(self.web_search_service_url, headers=headers, json=payload)
|
| 382 |
+
|
| 383 |
+
if response.status_code == 200:
|
| 384 |
+
response = response.json()
|
| 385 |
+
self.logger.info(response)
|
| 386 |
+
results.append(response["references"])
|
| 387 |
+
else:
|
| 388 |
+
self.logger.info(f"请求失败,状态码: {response.status_code}")
|
| 389 |
+
self.logger.info(response.text)
|
| 390 |
+
return results
|