Yoma
committed on
Commit
·
58a026c
1
Parent(s):
a6c14c5
format compatibility for Gradio 6.0
Browse files- chatbot_app.py +8 -42
- llm_interface.py +14 -2
chatbot_app.py
CHANGED
|
@@ -19,31 +19,6 @@ EMBEDDING_MODEL = 'BAAI/bge-large-en-v1.5'
|
|
| 19 |
PRODUCTS_JSON_PATH = 'products.json'
|
| 20 |
REVIEWS_JSON_PATH = 'product_reviews.json'
|
| 21 |
|
| 22 |
-
# # --- Check for and Build VectorDB if it doesn't exist ---
|
| 23 |
-
# # This is crucial for environments like HF Spaces where the file system is ephemeral.
|
| 24 |
-
# if not os.path.exists(DB_PATH):
|
| 25 |
-
# logger.info(f"ChromaDB path '{DB_PATH}' not found. Running ETL pipeline to create and populate the database.")
|
| 26 |
-
# logger.info("This may take a few moments...")
|
| 27 |
-
|
| 28 |
-
# # Check if data files exist before running ETL
|
| 29 |
-
# if not os.path.exists(PRODUCTS_JSON_PATH) or not os.path.exists(REVIEWS_JSON_PATH):
|
| 30 |
-
# logger.error(f"FATAL: Required data files ('{PRODUCTS_JSON_PATH}' or '{REVIEWS_JSON_PATH}') not found.")
|
| 31 |
-
# # Exit if data is missing, as the app cannot function
|
| 32 |
-
# exit()
|
| 33 |
-
|
| 34 |
-
# try:
|
| 35 |
-
# run_etl_pipeline(
|
| 36 |
-
# products_file=PRODUCTS_JSON_PATH,
|
| 37 |
-
# reviews_file=REVIEWS_JSON_PATH,
|
| 38 |
-
# db_path=DB_PATH,
|
| 39 |
-
# model_name=EMBEDDING_MODEL
|
| 40 |
-
# )
|
| 41 |
-
# logger.info("ETL pipeline completed successfully.")
|
| 42 |
-
# except Exception as e:
|
| 43 |
-
# logger.error(f"FATAL: An error occurred during the ETL pipeline: {e}", exc_info=True)
|
| 44 |
-
# # Exit if the ETL fails, as the app cannot function
|
| 45 |
-
# exit()
|
| 46 |
-
|
| 47 |
# 1. Instantiate the retrieval manager
|
| 48 |
# It will now connect to the newly created or existing database
|
| 49 |
retriever = RetrievalManager(db_path=DB_PATH, model_name=EMBEDDING_MODEL)
|
|
@@ -68,18 +43,18 @@ def respond(message, chat_history):
|
|
| 68 |
# 2. Moderate the user's query
|
| 69 |
if not llm_interface.moderate_query(message):
|
| 70 |
response = "I'm sorry, but your query violates our safety guidelines. I cannot process this request."
|
| 71 |
-
chat_history.append(
|
| 72 |
-
|
| 73 |
return "", chat_history, []
|
| 74 |
|
| 75 |
# 3. Rewrite the query for context
|
| 76 |
rewritten_query = llm_interface.rewrite_query(message, chat_history)
|
| 77 |
logger.info(f"Original query: '{message}' | Rewritten query: '{rewritten_query}'")
|
| 78 |
|
| 79 |
-
# 4. Retrieve relevant documents
|
| 80 |
search_results = retriever.search(rewritten_query)
|
| 81 |
|
| 82 |
-
# Process the search results
|
| 83 |
retrieved_docs = []
|
| 84 |
for collection_name, results in search_results.items():
|
| 85 |
if results and results.get('documents') and results['documents'][0]:
|
|
@@ -88,18 +63,11 @@ def respond(message, chat_history):
|
|
| 88 |
for i, doc_content in enumerate(docs):
|
| 89 |
retrieved_docs.append((doc_content, metadatas[i]))
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
# Previously, only the raw content (doc[0]) was passed to the LLM.
|
| 93 |
-
# This change ensures that key metadata fields, such as 'price', 'product_name',
|
| 94 |
-
# 'brand', and 'category', are explicitly included in the document string
|
| 95 |
-
# that the LLM processes. This makes the LLM aware of these details,
|
| 96 |
-
# allowing it to answer questions that rely on metadata.
|
| 97 |
doc_contents = []
|
| 98 |
for content, metadata in retrieved_docs:
|
| 99 |
-
# Start with the original document content
|
| 100 |
enhanced_content = content
|
| 101 |
|
| 102 |
-
# # Append key metadata fields if they exist
|
| 103 |
if metadata:
|
| 104 |
metadata_parts = []
|
| 105 |
if 'product_name' in metadata and metadata['product_name'] not in enhanced_content:
|
|
@@ -118,17 +86,15 @@ def respond(message, chat_history):
|
|
| 118 |
if metadata_parts:
|
| 119 |
enhanced_content += "\n" + ", ".join(metadata_parts)
|
| 120 |
doc_contents.append(enhanced_content)
|
| 121 |
-
|
| 122 |
-
# --- END CHANGE: Incorporate metadata into the content for the LLM ---
|
| 123 |
|
| 124 |
# 5. Generate a response using the LLM
|
| 125 |
response = llm_interface.generate_response(message, doc_contents, chat_history)
|
| 126 |
|
| 127 |
-
# 6. Append
|
| 128 |
-
chat_history.append(
|
|
|
|
| 129 |
|
| 130 |
# 7. Return values to update the Gradio UI
|
| 131 |
-
# The JSON component expects a serializable object (like a list of dicts)
|
| 132 |
docs_for_display = [
|
| 133 |
{"content": content, "metadata": metadata} for content, metadata in retrieved_docs
|
| 134 |
]
|
|
|
|
| 19 |
PRODUCTS_JSON_PATH = 'products.json'
|
| 20 |
REVIEWS_JSON_PATH = 'product_reviews.json'
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# 1. Instantiate the retrieval manager
|
| 23 |
# It will now connect to the newly created or existing database
|
| 24 |
retriever = RetrievalManager(db_path=DB_PATH, model_name=EMBEDDING_MODEL)
|
|
|
|
| 43 |
# 2. Moderate the user's query
|
| 44 |
if not llm_interface.moderate_query(message):
|
| 45 |
response = "I'm sorry, but your query violates our safety guidelines. I cannot process this request."
|
| 46 |
+
chat_history.append({"role": "user", "content": message})
|
| 47 |
+
chat_history.append({"role": "assistant", "content": response})
|
| 48 |
return "", chat_history, []
|
| 49 |
|
| 50 |
# 3. Rewrite the query for context
|
| 51 |
rewritten_query = llm_interface.rewrite_query(message, chat_history)
|
| 52 |
logger.info(f"Original query: '{message}' | Rewritten query: '{rewritten_query}'")
|
| 53 |
|
| 54 |
+
# 4. Retrieve relevant documents
|
| 55 |
search_results = retriever.search(rewritten_query)
|
| 56 |
|
| 57 |
+
# Process the search results
|
| 58 |
retrieved_docs = []
|
| 59 |
for collection_name, results in search_results.items():
|
| 60 |
if results and results.get('documents') and results['documents'][0]:
|
|
|
|
| 63 |
for i, doc_content in enumerate(docs):
|
| 64 |
retrieved_docs.append((doc_content, metadatas[i]))
|
| 65 |
|
| 66 |
+
# Incorporate metadata into content for LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
doc_contents = []
|
| 68 |
for content, metadata in retrieved_docs:
|
|
|
|
| 69 |
enhanced_content = content
|
| 70 |
|
|
|
|
| 71 |
if metadata:
|
| 72 |
metadata_parts = []
|
| 73 |
if 'product_name' in metadata and metadata['product_name'] not in enhanced_content:
|
|
|
|
| 86 |
if metadata_parts:
|
| 87 |
enhanced_content += "\n" + ", ".join(metadata_parts)
|
| 88 |
doc_contents.append(enhanced_content)
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# 5. Generate a response using the LLM
|
| 91 |
response = llm_interface.generate_response(message, doc_contents, chat_history)
|
| 92 |
|
| 93 |
+
# 6. Append messages in Gradio 6.0 format
|
| 94 |
+
chat_history.append({"role": "user", "content": message})
|
| 95 |
+
chat_history.append({"role": "assistant", "content": response})
|
| 96 |
|
| 97 |
# 7. Return values to update the Gradio UI
|
|
|
|
| 98 |
docs_for_display = [
|
| 99 |
{"content": content, "metadata": metadata} for content, metadata in retrieved_docs
|
| 100 |
]
|
llm_interface.py
CHANGED
|
@@ -293,7 +293,13 @@ def generate_response(query: str, retrieved_docs: list, history: list) -> str:
|
|
| 293 |
context = "\n\n---\n\n".join(doc for doc in retrieved_docs)
|
| 294 |
|
| 295 |
# Format chat history for the prompt
|
| 296 |
-
formatted_history = "\n".join([f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in history])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
prompt = system_prompt.format(context=context, chat_history=formatted_history)
|
| 299 |
|
|
@@ -357,7 +363,13 @@ def rewrite_query(query: str, history: list) -> str:
|
|
| 357 |
"""
|
| 358 |
|
| 359 |
# Format chat history for the prompt
|
| 360 |
-
formatted_history = "\n".join([f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in history])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
prompt = system_prompt.format(chat_history=formatted_history)
|
| 363 |
|
|
|
|
| 293 |
context = "\n\n---\n\n".join(doc for doc in retrieved_docs)
|
| 294 |
|
| 295 |
# Format chat history for the prompt
|
| 296 |
+
#formatted_history = "\n".join([f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in history])
|
| 297 |
+
formatted_history = ""
|
| 298 |
+
for msg in history:
|
| 299 |
+
if msg["role"] == "user":
|
| 300 |
+
formatted_history += f"User: {msg['content']}\n"
|
| 301 |
+
elif msg["role"] == "assistant":
|
| 302 |
+
formatted_history += f"Assistant: {msg['content']}\n"
|
| 303 |
|
| 304 |
prompt = system_prompt.format(context=context, chat_history=formatted_history)
|
| 305 |
|
|
|
|
| 363 |
"""
|
| 364 |
|
| 365 |
# Format chat history for the prompt
|
| 366 |
+
#formatted_history = "\n".join([f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in history])
|
| 367 |
+
formatted_history = ""
|
| 368 |
+
for msg in history:
|
| 369 |
+
if msg["role"] == "user":
|
| 370 |
+
formatted_history += f"User: {msg['content']}\n"
|
| 371 |
+
elif msg["role"] == "assistant":
|
| 372 |
+
formatted_history += f"Assistant: {msg['content']}\n"
|
| 373 |
|
| 374 |
prompt = system_prompt.format(chat_history=formatted_history)
|
| 375 |
|