Spaces:
Build error
Build error
Adding memory management
#8
by mukiibi - opened
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import pandas as pd
|
|
@@ -9,6 +10,8 @@ import chromadb
|
|
| 9 |
from langchain_chroma import Chroma
|
| 10 |
import gspread
|
| 11 |
from google.oauth2.service_account import Credentials
|
|
|
|
|
|
|
| 12 |
import json
|
| 13 |
from datetime import datetime
|
| 14 |
import re
|
|
@@ -36,7 +39,7 @@ client_gspread = gspread.authorize(get_google_sheets_credentials())
|
|
| 36 |
# Open the Google Sheet
|
| 37 |
sheet = client_gspread.open("Response_Log").sheet1
|
| 38 |
|
| 39 |
-
def log_response(question, answer, source_ids, knowledge_pairs):
|
| 40 |
"""
|
| 41 |
Log a question, answer, source IDs, and knowledge base question-answer pairs to the Google Sheet.
|
| 42 |
|
|
@@ -53,6 +56,7 @@ def log_response(question, answer, source_ids, knowledge_pairs):
|
|
| 53 |
knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
|
| 54 |
row = [
|
| 55 |
timestamp,
|
|
|
|
| 56 |
question,
|
| 57 |
answer,
|
| 58 |
source_ids,
|
|
@@ -69,6 +73,28 @@ def log_response(question, answer, source_ids, knowledge_pairs):
|
|
| 69 |
with open("/tmp/response_log.txt", "a") as f:
|
| 70 |
f.write(f"{timestamp},{question},{answer},{source_ids},{knowledge_question_1},{knowledge_answer_1},{knowledge_question_2},{knowledge_answer_2}\n")
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
# === Intent Classification System ===
|
| 73 |
class IntentClassifier:
|
| 74 |
def __init__(self):
|
|
@@ -203,106 +229,101 @@ def process_context(results, cosine_scores, max_results=2):
|
|
| 203 |
knowledge_pairs.append((question, answer))
|
| 204 |
return formatted_context, source_ids, knowledge_pairs
|
| 205 |
|
| 206 |
-
# === LLM Generation ===
|
| 207 |
-
def generate_xeno_response(context, question):
|
|
|
|
| 208 |
model = genai.GenerativeModel(llm_model_name)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
{question}"
|
|
|
|
| 214 |
response = model.generate_content(prompt)
|
| 215 |
return response.text.strip()
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
"""
|
| 220 |
-
|
| 221 |
"""
|
| 222 |
-
|
| 223 |
-
intent, direct_response = intent_classifier.classify_intent(message)
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
return direct_response
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
if len(message.strip()) < 3:
|
| 234 |
-
answer = "I'd be happy to help! Could you please provide more details about what you'd like to know
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
content=message,
|
| 244 |
-
task_type="retrieval_query"
|
| 245 |
-
)['embedding']
|
| 246 |
-
|
| 247 |
-
cosine_scores = []
|
| 248 |
-
for doc in queried_results:
|
| 249 |
-
doc_embedding = genai.embed_content(
|
| 250 |
-
model=embedding_model,
|
| 251 |
-
content=doc.page_content,
|
| 252 |
-
task_type="retrieval_document"
|
| 253 |
-
)['embedding']
|
| 254 |
-
cos_sim = util.cos_sim(
|
| 255 |
-
torch.tensor(query_embedding).float(),
|
| 256 |
-
torch.tensor(doc_embedding).float()
|
| 257 |
-
)[0][0].item()
|
| 258 |
-
cosine_scores.append(cos_sim)
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
return answer
|
| 270 |
-
|
| 271 |
-
except Exception as e:
|
| 272 |
-
answer = "I apologize, but I'm experiencing a technical issue. Please contact XENO support directly for assistance with your query."
|
| 273 |
-
log_response(message, answer, "N/A", [])
|
| 274 |
-
return answer
|
| 275 |
-
|
| 276 |
-
# Handle goodbye intent (not simple, but has direct response)
|
| 277 |
-
if intent == 'goodbye' and direct_response:
|
| 278 |
-
log_response(message, direct_response, "N/A", [])
|
| 279 |
-
return direct_response
|
| 280 |
|
| 281 |
-
|
| 282 |
-
answer
|
| 283 |
-
|
| 284 |
return answer
|
| 285 |
|
| 286 |
# === Enhanced Gradio UI ===
|
| 287 |
-
def
|
| 288 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
title=" ASKXENO",
|
| 293 |
-
description="""**Welcome to XENO AI Support!**
|
| 294 |
-
I can help you with questions about XENO financial services including:
|
| 295 |
-
• Account management and setup
|
| 296 |
-
• Transaction processes and fees
|
| 297 |
-
• Platform features and troubleshooting
|
| 298 |
-
• General service information
|
| 299 |
-
*Simply type your question below to get started!*""",
|
| 300 |
-
theme="soft"
|
| 301 |
-
)
|
| 302 |
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
# === Main Execution ===
|
| 306 |
if __name__ == "__main__":
|
| 307 |
iface = create_interface()
|
| 308 |
-
iface.launch(share=False)
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
import os
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
|
|
|
| 10 |
from langchain_chroma import Chroma
|
| 11 |
import gspread
|
| 12 |
from google.oauth2.service_account import Credentials
|
| 13 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 14 |
+
import sqlite3
|
| 15 |
import json
|
| 16 |
from datetime import datetime
|
| 17 |
import re
|
|
|
|
| 39 |
# Open the Google Sheet
|
| 40 |
sheet = client_gspread.open("Response_Log").sheet1
|
| 41 |
|
| 42 |
+
def log_response(question, answer, source_ids, knowledge_pairs, session_id):
|
| 43 |
"""
|
| 44 |
Log a question, answer, source IDs, and knowledge base question-answer pairs to the Google Sheet.
|
| 45 |
|
|
|
|
| 56 |
knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
|
| 57 |
row = [
|
| 58 |
timestamp,
|
| 59 |
+
session_id,
|
| 60 |
question,
|
| 61 |
answer,
|
| 62 |
source_ids,
|
|
|
|
| 73 |
with open("/tmp/response_log.txt", "a") as f:
|
| 74 |
f.write(f"{timestamp},{question},{answer},{source_ids},{knowledge_question_1},{knowledge_answer_1},{knowledge_question_2},{knowledge_answer_2}\n")
|
| 75 |
|
| 76 |
+
# === LangGraph Memory Setup ===
|
| 77 |
+
conn = sqlite3.connect("xeno_memory.db", check_same_thread=False)
|
| 78 |
+
memory = SqliteSaver(conn=conn)
|
| 79 |
+
|
| 80 |
+
def update_memory(config, user_message, assistant_message):
    """Append one user/assistant exchange to the stored conversation.

    Reads the current checkpoint for ``config`` from the module-level
    ``memory`` saver, adds the two new messages to its ``messages`` channel
    and writes the result back as a fresh checkpoint.

    Args:
        config: Checkpointer config, passed unchanged to ``memory.get`` and
            ``memory.put``.
        user_message: Text recorded under the ``user`` role.
        assistant_message: Text recorded under the ``assistant`` role.
    """
    # A missing checkpoint is treated as an empty conversation.
    stored = memory.get(config) or {}
    transcript = stored.get("channel_values", {}).get("messages", [])
    transcript.extend(
        (
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message},
        )
    )

    snapshot = dict(
        v=1,
        id=str(uuid.uuid4()),
        ts=datetime.now().isoformat(),
        channel_values={"messages": transcript},
        channel_versions={},
        versions_seen={},
    )

    # NOTE(review): the two trailing empty dicts are presumably the checkpoint
    # metadata and the new channel versions -- confirm against SqliteSaver.put.
    memory.put(config, snapshot, {}, {})
|
| 97 |
+
|
| 98 |
# === Intent Classification System ===
|
| 99 |
class IntentClassifier:
|
| 100 |
def __init__(self):
|
|
|
|
| 229 |
knowledge_pairs.append((question, answer))
|
| 230 |
return formatted_context, source_ids, knowledge_pairs
|
| 231 |
|
| 232 |
+
# === LLM Generation (Refactored) ===
|
| 233 |
+
def generate_xeno_response(context, question, chat_history):
    """Ask the LLM for an answer; conversation memory is not touched here.

    Args:
        context: Retrieved knowledge-base text placed under the CONTEXT header.
        question: The user's question placed under the QUESTION header.
        chat_history: Sequence of ``{"role": ..., "content": ...}`` messages;
            when empty or ``None`` the HISTORY section reads ``None``.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    llm = genai.GenerativeModel(llm_model_name)

    # Render the history as one "Role: content" line per message.
    if chat_history:
        history_lines = []
        for msg in chat_history:
            history_lines.append(f"{msg['role'].capitalize()}: {msg['content']}")
        formatted_history = "\n".join(history_lines)
    else:
        formatted_history = "None"

    prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"

    reply = llm.generate_content(prompt)
    return reply.text.strip()
|
| 244 |
|
| 245 |
+
|
| 246 |
+
# === Main Interface Logic (Refactored) ===
|
| 247 |
+
def get_context_and_answer(message, history, session_id="default"):
    """Classify the message, answer it (via RAG when needed), then persist it.

    Handles intent classification, retrieval-augmented generation, and the
    memory and logging updates in one place.

    Args:
        message: The user's message text.
        history: Accepted for interface compatibility; not read here. The
            conversation history used for generation comes from ``memory``.
        session_id: Identifier used as the checkpointer ``thread_id`` and
            passed to ``log_response``.

    Returns:
        The answer text. Every path also records the exchange through
        ``update_memory`` and ``log_response`` before returning.
    """
    not_found_answer = (
        "I'm sorry, I couldn't find specific information for your question. "
        "Could you try rephrasing it, or contact XENO support directly?"
    )

    config = {"configurable": {"thread_id": str(session_id), "checkpoint_ns": ""}}

    # A missing checkpoint is treated as an empty conversation.
    full_checkpoint = memory.get(config) or {}
    chat_history = full_checkpoint.get("channel_values", {}).get("messages", [])
    intent, direct_response = intent_classifier.classify_intent(message)

    answer = ""
    source_ids = "N/A"
    knowledge_pairs = []

    if intent != 'query':
        # NOTE(review): direct_response is used as-is; confirm that
        # classify_intent always supplies one for non-query intents.
        answer = direct_response
    elif len(message.strip()) < 3:
        answer = "I'd be happy to help! Could you please provide more details about what you'd like to know?"
    else:
        try:
            queried_results = retriever.invoke(message)

            if not queried_results:
                # Nothing retrieved: answer "not found" directly. Without this
                # guard the empty score list makes the code below raise, and
                # the user is told about a technical issue instead.
                answer = not_found_answer
            else:
                query_embedding = genai.embed_content(
                    model=embedding_model,
                    content=message,
                    task_type="retrieval_query",
                )['embedding']

                doc_embeddings = [
                    genai.embed_content(
                        model=embedding_model,
                        content=doc.page_content,
                        task_type="retrieval_document",
                    )['embedding']
                    for doc in queried_results
                ]

                # One similarity score per retrieved document.
                cosine_scores = util.cos_sim(
                    torch.tensor(query_embedding).float(),
                    torch.tensor(doc_embeddings).float(),
                )[0].tolist()

                if max(cosine_scores) < 0.4:
                    answer = not_found_answer
                else:
                    context, source_ids_list, knowledge_pairs = process_context(queried_results, cosine_scores)
                    answer = generate_xeno_response(context, message, chat_history)
                    source_ids = ", ".join(source_ids_list)

        except Exception as e:
            # Top-level boundary for the RAG path: report and degrade gracefully.
            print(f"Error during RAG processing: {e}")
            answer = "I apologize, but I'm having a technical issue. Please try again shortly or contact XENO support."

    update_memory(config, message, answer)
    log_response(message, answer, source_ids, knowledge_pairs, session_id)

    return answer
|
| 290 |
|
| 291 |
# === Enhanced Gradio UI ===
|
| 292 |
+
def respond(message, history, session_id):
    """Gradio's main response function.

    Args:
        message: Text the user submitted.
        history: Chat transcript as a list of ``{"role", "content"}`` dicts;
            ``None`` is treated as an empty transcript.
        session_id: Conversation identifier; a random UUID is substituted
            when it is empty.

    Returns:
        A ``("", history)`` tuple: an empty string followed by the transcript
        extended with the new user and assistant messages.
    """
    if not session_id:
        # NOTE(review): the generated ID is not returned to the caller, so it
        # applies to this call only.
        session_id = str(uuid.uuid4())

    if history is None:
        history = []

    response = get_context_and_answer(message, history, session_id)

    # The previous version re-read the checkpoint here into an unused variable
    # (and looked up the wrong key); that extra read has been removed.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})

    return "", history
|
| 306 |
+
def create_interface():
    """Build the Gradio Blocks UI for the XENO assistant.

    The layout is a welcome banner, an editable session-ID box, the chat
    transcript, and a message box whose submit event calls ``respond``.

    Returns:
        The ``gr.Blocks`` app, ready for ``launch()``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""ASKXENO

**Welcome to XENO AI Support!**
I can help you with questions about XENO financial services including:
• Account management and setup
• Transaction processes and fees
• Platform features and troubleshooting
• General service information
*Simply type your question below to get started!*
""")

        # Pass a callable so a fresh ID is generated each time the app loads.
        # A plain string would be computed once when the UI is built, giving
        # every visitor the same default session.
        session_id_box = gr.Textbox(
            label="Session ID",
            value=lambda: str(uuid.uuid4()),
            interactive=True,
        )

        chatbot = gr.Chatbot(label="XENO Assistant", bubble_full_width=False, height=500, type="messages")
        msg = gr.Textbox(label="Your Message", placeholder="Type your question here...")

        # respond returns ("", history): clear the input box, refresh the chat.
        msg.submit(respond, [msg, chatbot, session_id_box], [msg, chatbot])
    return demo
|
| 326 |
|
|
|
|
| 327 |
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 without a
    # public share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    iface = create_interface()
    iface.launch(**launch_options)
|