Spaces:
Build error
Build error
github-actions committed on
Commit ·
3cdce90
1
Parent(s): 0aa781d
Sync from GitHub
Browse files- .env.example +2 -0
- .github/workflows/tests.yml +3 -3
- XENO%20Uganda_KnowlegeBase_V1.json +0 -0
- app.py +74 -465
- docker-compose.yml +1 -0
- requirements.txt +1 -1
- src/config.py +4 -1
- src/intent_classifier.py +46 -45
- src/interface.py +121 -0
- src/knowledge_base.py +43 -31
- src/logger.py +190 -63
- src/memory.py +19 -14
- src/response_generator.py +30 -22
- src/utils.py +18 -11
- src/vector_store.py +60 -55
- tests/conftest.py +55 -32
- tests/test_app.py +411 -0
- tests/test_intent_classifier.py +62 -56
- tests/test_interface.py +135 -0
- tests/test_knowledge_base.py +77 -65
- tests/test_logger.py +95 -117
- tests/test_memory.py +51 -60
- tests/test_response_generator.py +77 -118
- tests/test_utils.py +43 -41
- tests/test_vector_store.py +221 -102
- tox.ini +25 -0
.env.example
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GOOGLE_SHEETS_CREDENTIALS = 'google_sheets_credentials.json'
|
| 2 |
+
GEMINI_API_KEY = "your_gemini_api_key_here"
|
.github/workflows/tests.yml
CHANGED
|
@@ -2,9 +2,9 @@ name: Run Tests
|
|
| 2 |
|
| 3 |
on:
|
| 4 |
push:
|
| 5 |
-
branches: [ main,
|
| 6 |
pull_request:
|
| 7 |
-
branches: [ main,
|
| 8 |
|
| 9 |
jobs:
|
| 10 |
test:
|
|
@@ -12,7 +12,7 @@ jobs:
|
|
| 12 |
|
| 13 |
strategy:
|
| 14 |
matrix:
|
| 15 |
-
python-version: ['3.
|
| 16 |
|
| 17 |
steps:
|
| 18 |
- uses: actions/checkout@v3
|
|
|
|
| 2 |
|
| 3 |
on:
|
| 4 |
push:
|
| 5 |
+
branches: [ main, development ]
|
| 6 |
pull_request:
|
| 7 |
+
branches: [ main, development ]
|
| 8 |
|
| 9 |
jobs:
|
| 10 |
test:
|
|
|
|
| 12 |
|
| 13 |
strategy:
|
| 14 |
matrix:
|
| 15 |
+
python-version: ['3.13']
|
| 16 |
|
| 17 |
steps:
|
| 18 |
- uses: actions/checkout@v3
|
XENO%20Uganda_KnowlegeBase_V1.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app.py
CHANGED
|
@@ -2,46 +2,23 @@
|
|
| 2 |
XENO Bot - AI-powered customer service assistant
|
| 3 |
Main application file with Gradio interface
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
-
import uuid
|
| 7 |
-
import gradio as gr
|
| 8 |
-
import pandas as pd
|
| 9 |
-
import torch
|
| 10 |
-
import numpy as np
|
| 11 |
-
from sentence_transformers import util
|
| 12 |
-
from google import genai
|
| 13 |
-
import chromadb
|
| 14 |
-
from langchain_chroma import Chroma
|
| 15 |
-
import gspread
|
| 16 |
-
from google.oauth2.service_account import Credentials
|
| 17 |
-
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 18 |
-
import sqlite3
|
| 19 |
-
import json
|
| 20 |
-
from datetime import datetime
|
| 21 |
-
import re
|
| 22 |
-
from typing import Dict, List, Tuple
|
| 23 |
-
import time
|
| 24 |
-
from contextlib import contextmanager
|
| 25 |
-
import threading # <--- Added for non-blocking feedback logging
|
| 26 |
import logging
|
|
|
|
| 27 |
import traceback
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
from src.config import SIMILARITY_THRESHOLD, SERVER_NAME, SERVER_PORT
|
| 32 |
-
from src.memory import create_session_config, update_memory, retrieve_memory
|
| 33 |
from src.intent_classifier import IntentClassifier
|
| 34 |
-
from src.
|
| 35 |
-
|
| 36 |
-
generate_embeddings,
|
| 37 |
-
calculate_similarity,
|
| 38 |
-
process_context
|
| 39 |
-
)
|
| 40 |
-
from src.response_generator import generate_xeno_response
|
| 41 |
from src.logger import log_response, log_timing_data
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# === Configuration ===
|
| 47 |
# Ensure API Key is set
|
|
@@ -49,351 +26,61 @@ if "GEMINI_API_KEY" not in os.environ:
|
|
| 49 |
print("WARNING: GEMINI_API_KEY environment variable not found.")
|
| 50 |
|
| 51 |
# Initialize the client
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
collection_name = "xeno_collection"
|
| 56 |
-
|
| 57 |
-
# === Google Sheets Setup ===
|
| 58 |
-
def get_google_sheets_credentials():
|
| 59 |
-
credentials_json = os.environ.get("GOOGLE_SHEETS_CREDENTIALS")
|
| 60 |
-
if not credentials_json:
|
| 61 |
-
raise ValueError("GOOGLE_SHEETS_CREDENTIALS environment variable not set.")
|
| 62 |
-
credentials_dict = json.loads(credentials_json)
|
| 63 |
-
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
|
| 64 |
-
creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
|
| 65 |
-
return creds
|
| 66 |
-
|
| 67 |
-
# Authenticate
|
| 68 |
-
try:
|
| 69 |
-
client_gspread = gspread.authorize(get_google_sheets_credentials())
|
| 70 |
-
spreadsheet = client_gspread.open("Response_Log")
|
| 71 |
-
response_sheet = spreadsheet.sheet1
|
| 72 |
-
except Exception as e:
|
| 73 |
-
print(f"Error connecting to Google Sheets: {e}")
|
| 74 |
-
# Create dummy objects if connection fails to prevent app crash during dev
|
| 75 |
-
class DummySheet:
|
| 76 |
-
def append_row(self, *args, **kwargs): pass
|
| 77 |
-
def worksheet(self, *args): return self
|
| 78 |
-
def add_worksheet(self, *args, **kwargs): return self
|
| 79 |
-
spreadsheet = DummySheet()
|
| 80 |
-
response_sheet = DummySheet()
|
| 81 |
-
|
| 82 |
-
# Setup Timing Sheet
|
| 83 |
-
try:
|
| 84 |
-
timing_sheet = spreadsheet.worksheet("Timing_Log")
|
| 85 |
-
except:
|
| 86 |
-
try:
|
| 87 |
-
timing_sheet = spreadsheet.add_worksheet(title="Timing_Log", rows="1000", cols="15")
|
| 88 |
-
headers = [
|
| 89 |
-
"Timestamp", "Session_ID", "Question", "Total_Time_MS",
|
| 90 |
-
"Intent_Classification_MS", "Memory_Retrieval_MS", "RAG_Retrieval_MS",
|
| 91 |
-
"Embedding_Generation_MS", "Similarity_Calculation_MS", "Context_Processing_MS",
|
| 92 |
-
"LLM_Generation_MS", "Memory_Update_MS", "Logging_MS", "Error_Step", "Notes"
|
| 93 |
-
]
|
| 94 |
-
timing_sheet.append_row(headers)
|
| 95 |
-
except Exception as e:
|
| 96 |
-
print(f"Could not create Timing_Log sheet: {e}")
|
| 97 |
-
timing_sheet = None
|
| 98 |
-
|
| 99 |
-
# === NEW: Setup Feedback Sheet ===
|
| 100 |
-
try:
|
| 101 |
-
feedback_sheet = spreadsheet.worksheet("Feedback_Log")
|
| 102 |
-
except:
|
| 103 |
-
try:
|
| 104 |
-
feedback_sheet = spreadsheet.add_worksheet(title="Feedback_Log", rows="1000", cols="6")
|
| 105 |
-
headers = ["Timestamp", "Session_ID", "User_Message", "Bot_Response", "Rating", "Flag_Reason"]
|
| 106 |
-
feedback_sheet.append_row(headers)
|
| 107 |
-
except Exception as e:
|
| 108 |
-
print(f"Could not create Feedback_Log sheet: {e}")
|
| 109 |
-
feedback_sheet = None
|
| 110 |
-
|
| 111 |
-
# === Logging Functions ===
|
| 112 |
-
|
| 113 |
-
def log_response(question, answer, source_ids, knowledge_pairs, session_id):
|
| 114 |
-
"""Original response logging function"""
|
| 115 |
-
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 116 |
-
knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
|
| 117 |
-
knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
|
| 118 |
-
knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
|
| 119 |
-
knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
|
| 120 |
-
row = [
|
| 121 |
-
timestamp, session_id, question, answer, source_ids,
|
| 122 |
-
knowledge_question_1, knowledge_answer_1, knowledge_question_2, knowledge_answer_2
|
| 123 |
-
]
|
| 124 |
-
try:
|
| 125 |
-
response_sheet.append_row(row)
|
| 126 |
-
print(f"Logged response: {question} | Source IDs: {source_ids}")
|
| 127 |
-
except Exception as e:
|
| 128 |
-
print(f"Failed to log to Google Sheet: {e}")
|
| 129 |
-
with open("/tmp/response_log.txt", "a") as f:
|
| 130 |
-
f.write(f"{timestamp},{question},{answer},{source_ids}\n")
|
| 131 |
-
|
| 132 |
-
def log_timing_data(question, session_id, timing_summary, error_step=None, notes=None):
|
| 133 |
-
"""Log timing data to the timing sheet"""
|
| 134 |
-
if timing_sheet is None: return
|
| 135 |
-
|
| 136 |
-
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 137 |
-
step_times = timing_summary['step_times']
|
| 138 |
-
|
| 139 |
-
row = [
|
| 140 |
-
timestamp,
|
| 141 |
-
session_id,
|
| 142 |
-
question[:100] + "..." if len(question) > 100 else question,
|
| 143 |
-
timing_summary['total_time_ms'],
|
| 144 |
-
step_times.get('intent_classification', 0),
|
| 145 |
-
step_times.get('memory_retrieval', 0),
|
| 146 |
-
step_times.get('rag_retrieval', 0),
|
| 147 |
-
step_times.get('embedding_generation', 0),
|
| 148 |
-
step_times.get('similarity_calculation', 0),
|
| 149 |
-
step_times.get('context_processing', 0),
|
| 150 |
-
step_times.get('llm_generation', 0),
|
| 151 |
-
step_times.get('memory_update', 0),
|
| 152 |
-
step_times.get('response_logging', 0),
|
| 153 |
-
error_step or "",
|
| 154 |
-
notes or ""
|
| 155 |
-
]
|
| 156 |
-
|
| 157 |
-
try:
|
| 158 |
-
timing_sheet.append_row(row)
|
| 159 |
-
print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
|
| 160 |
-
except Exception as e:
|
| 161 |
-
print(f"Failed to log timing data: {e}")
|
| 162 |
-
|
| 163 |
-
# === NEW: Feedback Functions ===
|
| 164 |
-
|
| 165 |
-
def _log_feedback_background(row):
|
| 166 |
-
"""Helper to run network request in background thread"""
|
| 167 |
-
try:
|
| 168 |
-
if feedback_sheet:
|
| 169 |
-
feedback_sheet.append_row(row)
|
| 170 |
-
print("Feedback logged successfully.")
|
| 171 |
-
else:
|
| 172 |
-
print("Feedback sheet not available.")
|
| 173 |
-
except Exception as e:
|
| 174 |
-
print(f"Failed to log feedback: {e}")
|
| 175 |
-
|
| 176 |
-
def submit_feedback(rating, reason, history, session_id):
|
| 177 |
-
"""
|
| 178 |
-
Handles user feedback submission.
|
| 179 |
-
rating: 'Positive' or 'Negative'
|
| 180 |
-
reason: User provided text
|
| 181 |
-
history: Gradio chat history list
|
| 182 |
-
"""
|
| 183 |
-
if not history or len(history) == 0:
|
| 184 |
-
return "No conversation to rate yet."
|
| 185 |
-
|
| 186 |
-
# Get the last interaction (Gradio history is a list of lists: [[user, bot], ...])
|
| 187 |
-
last_interaction = history[-1]
|
| 188 |
-
|
| 189 |
-
# Safety check for history format
|
| 190 |
-
if isinstance(last_interaction, list) and len(last_interaction) >= 2:
|
| 191 |
-
user_msg = last_interaction[0]
|
| 192 |
-
bot_msg = last_interaction[1]
|
| 193 |
-
else:
|
| 194 |
-
return "Error reading conversation history."
|
| 195 |
-
|
| 196 |
-
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 197 |
-
|
| 198 |
-
# Prepare row data
|
| 199 |
-
row = [timestamp, session_id, user_msg, bot_msg, rating, reason]
|
| 200 |
-
|
| 201 |
-
# Run in thread to prevent UI blocking
|
| 202 |
-
threading.Thread(target=_log_feedback_background, args=(row,)).start()
|
| 203 |
-
|
| 204 |
-
return f"Feedback received ({rating}). Thank you!"
|
| 205 |
-
|
| 206 |
-
# === LangGraph Memory Setup ===
|
| 207 |
-
conn = sqlite3.connect("xeno_memory.db", check_same_thread=False)
|
| 208 |
-
memory = SqliteSaver(conn=conn)
|
| 209 |
-
|
| 210 |
-
def update_memory(config, user_message, assistant_message):
|
| 211 |
-
with timer.time_step("memory_update"):
|
| 212 |
-
full_checkpoint = memory.get(config) or {}
|
| 213 |
-
messages = full_checkpoint.get("channel_values", {}).get("messages", [])
|
| 214 |
-
|
| 215 |
-
messages.append({"role": "user", "content": user_message})
|
| 216 |
-
messages.append({"role": "assistant", "content": assistant_message})
|
| 217 |
-
|
| 218 |
-
checkpoint_to_save = {
|
| 219 |
-
"v": 1,
|
| 220 |
-
"id": str(uuid.uuid4()),
|
| 221 |
-
"ts": datetime.now().isoformat(),
|
| 222 |
-
"channel_values": {"messages": messages},
|
| 223 |
-
"channel_versions": {},
|
| 224 |
-
"versions_seen": {},
|
| 225 |
-
}
|
| 226 |
-
|
| 227 |
-
memory.put(config, checkpoint_to_save, {}, {})
|
| 228 |
-
|
| 229 |
-
def retrieve_memory(config):
|
| 230 |
-
with timer.time_step("memory_retrieval"):
|
| 231 |
-
full_checkpoint = memory.get(config) or {}
|
| 232 |
-
return full_checkpoint.get("channel_values", {}).get("messages", [])
|
| 233 |
|
| 234 |
# === Intent Classification System ===
|
| 235 |
-
class IntentClassifier:
|
| 236 |
-
def __init__(self):
|
| 237 |
-
self.intent_patterns = {
|
| 238 |
-
'greeting': {
|
| 239 |
-
'patterns': [
|
| 240 |
-
r'\b(hi|hello|hey|good morning|good afternoon|good evening|greetings)\b',
|
| 241 |
-
r'^(hi|hello|hey)[\s!.]*$',
|
| 242 |
-
r'\b(how are you|how do you do)\b'
|
| 243 |
-
],
|
| 244 |
-
'responses': [
|
| 245 |
-
"Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
|
| 246 |
-
"Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
|
| 247 |
-
"Good day! Welcome to XENO Support. How may I assist you today?"
|
| 248 |
-
]
|
| 249 |
-
},
|
| 250 |
-
'thanks': {
|
| 251 |
-
'patterns': [
|
| 252 |
-
r'\b(thank you|thanks|thank u|thx|appreciate|grateful)\b',
|
| 253 |
-
r'^(thanks|thank you)[\s!.]*$',
|
| 254 |
-
r'\b(much appreciated|thanks a lot|thank you so much)\b'
|
| 255 |
-
],
|
| 256 |
-
'responses': [
|
| 257 |
-
"You're welcome! Is there anything else I can help you with regarding XENO services?",
|
| 258 |
-
"Happy to help! Feel free to ask if you have any other questions about XENO.",
|
| 259 |
-
"Glad I could assist you! Let me know if you need help with anything else."
|
| 260 |
-
]
|
| 261 |
-
},
|
| 262 |
-
'goodbye': {
|
| 263 |
-
'patterns': [
|
| 264 |
-
r'\b(bye|goodbye|see you|farewell|take care|have a good day)\b',
|
| 265 |
-
r'^(bye|goodbye)[\s!.]*$',
|
| 266 |
-
r'\b(talk to you later|see you later|until next time)\b'
|
| 267 |
-
],
|
| 268 |
-
'responses': [
|
| 269 |
-
"Goodbye! Thank you for using XENO services. Have a great day!",
|
| 270 |
-
"Take care! Feel free to return anytime you need help with XENO services.",
|
| 271 |
-
"Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO."
|
| 272 |
-
]
|
| 273 |
-
}
|
| 274 |
-
}
|
| 275 |
-
|
| 276 |
-
def classify_intent(self, message: str) -> Tuple[str, str]:
|
| 277 |
-
message_lower = message.lower().strip()
|
| 278 |
-
for intent_name, intent_data in self.intent_patterns.items():
|
| 279 |
-
for pattern in intent_data['patterns']:
|
| 280 |
-
if re.search(pattern, message_lower, re.IGNORECASE):
|
| 281 |
-
import random
|
| 282 |
-
response = random.choice(intent_data['responses'])
|
| 283 |
-
return intent_name, response
|
| 284 |
-
return 'query', ''
|
| 285 |
-
|
| 286 |
intent_classifier = IntentClassifier()
|
| 287 |
|
| 288 |
# === Load and Clean Knowledge Base ===
|
| 289 |
-
|
| 290 |
-
df_kb = pd.read_json("XENO_Uganda_KnowledgeBase_Advisory.json")
|
| 291 |
-
df_kb.dropna(subset=['Content'], inplace=True)
|
| 292 |
-
|
| 293 |
-
def prepare_documents(data):
|
| 294 |
-
documents, metadatas, ids = [], [], []
|
| 295 |
-
for item in data:
|
| 296 |
-
documents.append(f"Question: {item['Question']}\nAnswer: {item['Content']}")
|
| 297 |
-
metadatas.append({
|
| 298 |
-
"question": item["Question"],
|
| 299 |
-
"content": item["Content"],
|
| 300 |
-
"id": str(item["ID"])
|
| 301 |
-
})
|
| 302 |
-
ids.append(str(item["ID"]))
|
| 303 |
-
return documents, metadatas, ids
|
| 304 |
-
|
| 305 |
-
xeno_data_list = df_kb.to_dict('records')
|
| 306 |
-
documents, metadatas, ids = prepare_documents(xeno_data_list)
|
| 307 |
-
except Exception as e:
|
| 308 |
-
print(f"Warning: Could not load JSON knowledge base: {e}")
|
| 309 |
-
documents, metadatas, ids = [], [], []
|
| 310 |
|
| 311 |
# === Setup ChromaDB ===
|
| 312 |
-
|
| 313 |
-
client = chromadb.PersistentClient(path="/tmp/xeno_db")
|
| 314 |
-
try:
|
| 315 |
-
collection = client.get_collection(name=collection_name)
|
| 316 |
-
print(f"Loaded existing ChromaDB collection: {collection_name}")
|
| 317 |
-
except:
|
| 318 |
-
print(f"Creating new ChromaDB collection: {collection_name}")
|
| 319 |
-
collection = client.create_collection(name=collection_name)
|
| 320 |
-
if documents:
|
| 321 |
-
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
| 322 |
-
except Exception as e:
|
| 323 |
-
print(f"Failed to initialize ChromaDB: {e}")
|
| 324 |
-
raise
|
| 325 |
|
| 326 |
-
vector_store = Chroma(client=client, collection_name=collection_name)
|
| 327 |
-
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
|
| 328 |
-
|
| 329 |
-
# === Prompt System ===
|
| 330 |
-
SYSTEM_PROMPT = """You are a friendly XENO Support Assistant, an AI-powered helpful and professional customer service representative.
|
| 331 |
-
Use only the information provided in the knowledge base context to answer user queries.
|
| 332 |
-
Do not hallucinate. If context doesn't contain relevant info, say so in a calm polite manner by saying I'm sorry, I can't assist with that.
|
| 333 |
-
Only use context that is clearly relevant to the user's question.
|
| 334 |
-
For greetings like "hi" or "hello", respond politely without using the context.
|
| 335 |
-
remember previous conversations."""
|
| 336 |
-
|
| 337 |
-
# === Context Processing ===
|
| 338 |
-
def process_context(results, cosine_scores, max_results=2):
|
| 339 |
-
with timer.time_step("context_processing"):
|
| 340 |
-
sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
|
| 341 |
-
formatted_context = ""
|
| 342 |
-
source_ids = []
|
| 343 |
-
knowledge_pairs = []
|
| 344 |
-
for i, idx in enumerate(sorted_indices, 1):
|
| 345 |
-
result = results[idx]
|
| 346 |
-
score = cosine_scores[idx]
|
| 347 |
-
question = result.metadata.get('question', 'N/A')
|
| 348 |
-
answer = result.metadata.get('content', 'N/A')
|
| 349 |
-
formatted_context += f"Knowledge Entry {i}:\n"
|
| 350 |
-
formatted_context += f"Q: {question}\n"
|
| 351 |
-
formatted_context += f"A: {answer}\n"
|
| 352 |
-
formatted_context += "-" * 40 + "\n"
|
| 353 |
-
source_ids.append(str(result.metadata.get('id', 'N/A')))
|
| 354 |
-
knowledge_pairs.append((question, answer))
|
| 355 |
-
return formatted_context, source_ids, knowledge_pairs
|
| 356 |
-
|
| 357 |
-
# === LLM Generation ===
|
| 358 |
-
def generate_xeno_response(context, question, chat_history):
|
| 359 |
-
with timer.time_step("llm_generation"):
|
| 360 |
-
formatted_history = "\n".join(
|
| 361 |
-
[f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history]
|
| 362 |
-
) if chat_history else "None"
|
| 363 |
-
|
| 364 |
-
prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
|
| 365 |
-
|
| 366 |
-
response = genai_client.models.generate_content(
|
| 367 |
-
model=llm_model_name,
|
| 368 |
-
contents={"text": prompt},
|
| 369 |
-
)
|
| 370 |
-
return response.text.strip()
|
| 371 |
|
| 372 |
-
# ===
|
| 373 |
-
def get_context_and_answer(
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
timer.reset()
|
| 376 |
error_step = None
|
| 377 |
notes = []
|
| 378 |
-
|
| 379 |
try:
|
| 380 |
-
# Create session config
|
| 381 |
-
|
| 382 |
-
|
| 383 |
# Step 1: Intent Classification
|
| 384 |
intent, direct_response = intent_classifier.classify_intent(message)
|
| 385 |
-
|
| 386 |
# Step 2: Memory Retrieval
|
| 387 |
-
chat_history = retrieve_memory(
|
| 388 |
-
|
| 389 |
answer = ""
|
| 390 |
source_ids = "N/A"
|
| 391 |
knowledge_pairs = []
|
| 392 |
|
| 393 |
-
if intent !=
|
| 394 |
answer = direct_response
|
| 395 |
notes.append(f"Simple intent: {intent}")
|
| 396 |
-
else:
|
| 397 |
if len(message.strip()) < 3:
|
| 398 |
answer = "I'd be happy to help! Could you please provide more details about what you'd like to know?"
|
| 399 |
notes.append("Message too short")
|
|
@@ -402,17 +89,19 @@ def get_context_and_answer(message, history, session_id="default"):
|
|
| 402 |
# Step 3: RAG Retrieval
|
| 403 |
with timer.time_step("rag_retrieval"):
|
| 404 |
queried_results = retriever.invoke(message)
|
| 405 |
-
|
| 406 |
# Step 4: Embedding Generation
|
| 407 |
query_embedding, doc_embeddings = generate_embeddings(
|
| 408 |
message, queried_results, timer
|
| 409 |
)
|
| 410 |
-
|
| 411 |
# Step 5: Similarity Calculation
|
| 412 |
with timer.time_step("similarity_calculation"):
|
|
|
|
|
|
|
| 413 |
cosine_scores = util.cos_sim(
|
| 414 |
-
torch.tensor(query_embedding).float(),
|
| 415 |
-
torch.tensor(doc_embeddings).float()
|
| 416 |
)[0].tolist()
|
| 417 |
max_score = max(cosine_scores) if cosine_scores else 0
|
| 418 |
|
|
@@ -421,8 +110,10 @@ def get_context_and_answer(message, history, session_id="default"):
|
|
| 421 |
notes.append(f"Low similarity score: {max_score:.3f}")
|
| 422 |
else:
|
| 423 |
# Step 6: Context Processing
|
| 424 |
-
context, source_ids_list, knowledge_pairs = process_context(
|
| 425 |
-
|
|
|
|
|
|
|
| 426 |
# Step 7: LLM Generation
|
| 427 |
answer = generate_xeno_response(context, message, chat_history)
|
| 428 |
source_ids = ", ".join(source_ids_list)
|
|
@@ -436,126 +127,44 @@ def get_context_and_answer(message, history, session_id="default"):
|
|
| 436 |
notes.append(f"Error: {str(e)}")
|
| 437 |
|
| 438 |
# Step 8: Memory Update
|
| 439 |
-
update_memory(
|
| 440 |
-
|
| 441 |
# Step 9: Response Logging
|
| 442 |
log_response(message, answer, source_ids, knowledge_pairs, session_id)
|
| 443 |
-
|
| 444 |
# Log timing data
|
| 445 |
timing_summary = timer.get_timing_summary()
|
| 446 |
log_timing_data(
|
| 447 |
-
message,
|
| 448 |
-
session_id,
|
| 449 |
-
timing_summary,
|
| 450 |
error_step=error_step,
|
| 451 |
-
notes="; ".join(notes) if notes else None
|
| 452 |
)
|
| 453 |
-
|
| 454 |
return answer
|
| 455 |
-
|
| 456 |
except Exception as e:
|
| 457 |
error_step = timer.current_step or "main_pipeline"
|
| 458 |
logging.error(f"Error in main pipeline: {e}")
|
| 459 |
logging.error(traceback.format_exc())
|
| 460 |
-
|
| 461 |
timing_summary = timer.get_timing_summary()
|
| 462 |
log_timing_data(
|
| 463 |
-
message,
|
| 464 |
-
session_id,
|
| 465 |
-
timing_summary,
|
| 466 |
error_step=error_step,
|
| 467 |
-
notes=f"Pipeline error: {str(e)}"
|
| 468 |
)
|
| 469 |
-
|
| 470 |
-
return "I apologize, but I encountered an error processing your request. Please try again."
|
| 471 |
-
|
| 472 |
|
| 473 |
-
|
| 474 |
-
def respond(message: str, history: List, session_id: str):
|
| 475 |
-
"""Gradio's main response function"""
|
| 476 |
-
if not session_id:
|
| 477 |
-
session_id = str(uuid.uuid4())
|
| 478 |
-
|
| 479 |
-
bot_response = get_context_and_answer(message, history, session_id)
|
| 480 |
-
history.append([message, bot_response])
|
| 481 |
-
|
| 482 |
-
return "", history
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
def create_interface():
|
| 486 |
-
"""Create Gradio interface"""
|
| 487 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 488 |
-
gr.Markdown("""
|
| 489 |
-
# ASKXENO
|
| 490 |
-
**Welcome to XENO AI Support!**
|
| 491 |
-
|
| 492 |
-
I can help you with questions about XENO financial services including:
|
| 493 |
-
- Account management and setup
|
| 494 |
-
- Transaction processes and fees
|
| 495 |
-
- Platform features and troubleshooting
|
| 496 |
-
- General service information
|
| 497 |
-
|
| 498 |
-
*Simply type your question below to get started!*
|
| 499 |
-
""")
|
| 500 |
-
|
| 501 |
-
# Hidden state for session
|
| 502 |
-
session_id_box = gr.Textbox(label="Session ID", value=str(uuid.uuid4()), visible=False)
|
| 503 |
-
|
| 504 |
-
chatbot = gr.Chatbot(
|
| 505 |
-
label="XENO Assistant",
|
| 506 |
-
bubble_full_width=False,
|
| 507 |
-
height=450
|
| 508 |
-
)
|
| 509 |
-
|
| 510 |
-
with gr.Row():
|
| 511 |
-
msg = gr.Textbox(
|
| 512 |
-
label="Your Message",
|
| 513 |
-
placeholder="Type your question here...",
|
| 514 |
-
scale=4,
|
| 515 |
-
)
|
| 516 |
-
send_button = gr.Button("Send", variant="primary", scale=1)
|
| 517 |
-
|
| 518 |
-
# ===== FEEDBACK SECTION =====
|
| 519 |
-
with gr.Row():
|
| 520 |
-
with gr.Accordion("Rate this response / Flag Issue", open=False):
|
| 521 |
-
with gr.Row():
|
| 522 |
-
thumbs_up = gr.Button("👍 Good Answer")
|
| 523 |
-
thumbs_down = gr.Button("👎 Bad / Flag")
|
| 524 |
-
|
| 525 |
-
feedback_reason = gr.Textbox(
|
| 526 |
-
label="Reason ",
|
| 527 |
-
placeholder="E.g., Incorrect fees, hallucination,"
|
| 528 |
-
)
|
| 529 |
-
feedback_status = gr.Label(value="", label="Status", show_label=False)
|
| 530 |
-
|
| 531 |
-
# Feedback Event Listeners
|
| 532 |
-
# Logic: If Thumbs Up is clicked, send 'Positive'. If Textbox is empty, reason defaults to "Good".
|
| 533 |
-
thumbs_up.click(
|
| 534 |
-
fn=lambda h, s, r: submit_feedback("Positive", r if r else "Good", h, s),
|
| 535 |
-
inputs=[chatbot, session_id_box, feedback_reason],
|
| 536 |
-
outputs=[feedback_status]
|
| 537 |
-
)
|
| 538 |
-
|
| 539 |
-
# Logic: If Thumbs Down is clicked, send 'Negative' with the content of the textbox.
|
| 540 |
-
thumbs_down.click(
|
| 541 |
-
fn=lambda r, h, s: submit_feedback("Negative", r, h, s),
|
| 542 |
-
inputs=[feedback_reason, chatbot, session_id_box],
|
| 543 |
-
outputs=[feedback_status]
|
| 544 |
-
)
|
| 545 |
-
# =============================
|
| 546 |
|
| 547 |
-
# Chat Event Listeners
|
| 548 |
-
send_button.click(respond, [msg, chatbot, session_id_box], [msg, chatbot])
|
| 549 |
-
msg.submit(respond, [msg, chatbot, session_id_box], [msg, chatbot])
|
| 550 |
-
|
| 551 |
-
return demo
|
| 552 |
|
|
|
|
| 553 |
|
| 554 |
if __name__ == "__main__":
|
| 555 |
-
iface = create_interface()
|
| 556 |
iface.launch(
|
| 557 |
-
share=False,
|
| 558 |
-
|
| 559 |
-
server_port=SERVER_PORT,
|
| 560 |
-
ssr_mode=False
|
| 561 |
-
)
|
|
|
|
| 2 |
XENO Bot - AI-powered customer service assistant
|
| 3 |
Main application file with Gradio interface
|
| 4 |
"""
|
| 5 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import logging
|
| 7 |
+
import os
|
| 8 |
import traceback
|
| 9 |
|
| 10 |
+
from src.config import (COLLECTION_NAME, EMBEDDING_MODEL, LLM_MODEL_NAME,
|
| 11 |
+
SERVER_NAME, SERVER_PORT, SIMILARITY_THRESHOLD)
|
|
|
|
|
|
|
| 12 |
from src.intent_classifier import IntentClassifier
|
| 13 |
+
from src.interface import create_interface
|
| 14 |
+
from src.knowledge_base import get_knowledge_base_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from src.logger import log_response, log_timing_data
|
| 16 |
+
from src.memory import create_session_config, retrieve_memory, update_memory
|
| 17 |
+
from src.response_generator import generate_xeno_response
|
| 18 |
+
# Import custom modules
|
| 19 |
+
from src.utils import PipelineTimer
|
| 20 |
+
from src.vector_store import (generate_embeddings, initialize_vector_store,
|
| 21 |
+
process_context)
|
| 22 |
|
| 23 |
# === Configuration ===
|
| 24 |
# Ensure API Key is set
|
|
|
|
| 26 |
print("WARNING: GEMINI_API_KEY environment variable not found.")
|
| 27 |
|
| 28 |
# Initialize the client
|
| 29 |
+
embedding_model = EMBEDDING_MODEL
|
| 30 |
+
llm_model_name = LLM_MODEL_NAME
|
| 31 |
+
collection_name = COLLECTION_NAME
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# === Intent Classification System ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
intent_classifier = IntentClassifier()
|
| 35 |
|
| 36 |
# === Load and Clean Knowledge Base ===
|
| 37 |
+
documents, metadatas, ids = get_knowledge_base_data()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# === Setup ChromaDB ===
|
| 40 |
+
collection, vector_store, retriever = initialize_vector_store()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
# === Core Orchestration Logic ===
|
| 44 |
+
def get_context_and_answer(
|
| 45 |
+
message, history, session_id, intent_classifier, retriever
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
Core orchestration function that handles the RAG pipeline
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
message: User's message
|
| 52 |
+
history: Chat history
|
| 53 |
+
session_id: Session identifier
|
| 54 |
+
intent_classifier: IntentClassifier instance
|
| 55 |
+
retriever: Vector store retriever instance
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Generated answer string
|
| 59 |
+
"""
|
| 60 |
+
# Create timer per session
|
| 61 |
+
timer = PipelineTimer()
|
| 62 |
timer.reset()
|
| 63 |
error_step = None
|
| 64 |
notes = []
|
| 65 |
+
|
| 66 |
try:
|
| 67 |
+
# Create session memory config
|
| 68 |
+
memory_config = create_session_config(session_id)
|
| 69 |
+
|
| 70 |
# Step 1: Intent Classification
|
| 71 |
intent, direct_response = intent_classifier.classify_intent(message)
|
| 72 |
+
|
| 73 |
# Step 2: Memory Retrieval
|
| 74 |
+
chat_history = retrieve_memory(memory_config)
|
| 75 |
+
|
| 76 |
answer = ""
|
| 77 |
source_ids = "N/A"
|
| 78 |
knowledge_pairs = []
|
| 79 |
|
| 80 |
+
if intent != "query":
|
| 81 |
answer = direct_response
|
| 82 |
notes.append(f"Simple intent: {intent}")
|
| 83 |
+
else:
|
| 84 |
if len(message.strip()) < 3:
|
| 85 |
answer = "I'd be happy to help! Could you please provide more details about what you'd like to know?"
|
| 86 |
notes.append("Message too short")
|
|
|
|
| 89 |
# Step 3: RAG Retrieval
|
| 90 |
with timer.time_step("rag_retrieval"):
|
| 91 |
queried_results = retriever.invoke(message)
|
| 92 |
+
|
| 93 |
# Step 4: Embedding Generation
|
| 94 |
query_embedding, doc_embeddings = generate_embeddings(
|
| 95 |
message, queried_results, timer
|
| 96 |
)
|
| 97 |
+
|
| 98 |
# Step 5: Similarity Calculation
|
| 99 |
with timer.time_step("similarity_calculation"):
|
| 100 |
+
import sentence_transformers.util as util
|
| 101 |
+
import torch
|
| 102 |
cosine_scores = util.cos_sim(
|
| 103 |
+
torch.tensor(query_embedding).float(),
|
| 104 |
+
torch.tensor(doc_embeddings).float(),
|
| 105 |
)[0].tolist()
|
| 106 |
max_score = max(cosine_scores) if cosine_scores else 0
|
| 107 |
|
|
|
|
| 110 |
notes.append(f"Low similarity score: {max_score:.3f}")
|
| 111 |
else:
|
| 112 |
# Step 6: Context Processing
|
| 113 |
+
context, source_ids_list, knowledge_pairs = process_context(
|
| 114 |
+
queried_results, cosine_scores
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
# Step 7: LLM Generation
|
| 118 |
answer = generate_xeno_response(context, message, chat_history)
|
| 119 |
source_ids = ", ".join(source_ids_list)
|
|
|
|
| 127 |
notes.append(f"Error: {str(e)}")
|
| 128 |
|
| 129 |
# Step 8: Memory Update
|
| 130 |
+
update_memory(memory_config, message, answer)
|
| 131 |
+
|
| 132 |
# Step 9: Response Logging
|
| 133 |
log_response(message, answer, source_ids, knowledge_pairs, session_id)
|
| 134 |
+
|
| 135 |
# Log timing data
|
| 136 |
timing_summary = timer.get_timing_summary()
|
| 137 |
log_timing_data(
|
| 138 |
+
message,
|
| 139 |
+
session_id,
|
| 140 |
+
timing_summary,
|
| 141 |
error_step=error_step,
|
| 142 |
+
notes="; ".join(notes) if notes else None,
|
| 143 |
)
|
| 144 |
+
|
| 145 |
return answer
|
| 146 |
+
|
| 147 |
except Exception as e:
|
| 148 |
error_step = timer.current_step or "main_pipeline"
|
| 149 |
logging.error(f"Error in main pipeline: {e}")
|
| 150 |
logging.error(traceback.format_exc())
|
| 151 |
+
|
| 152 |
timing_summary = timer.get_timing_summary()
|
| 153 |
log_timing_data(
|
| 154 |
+
message,
|
| 155 |
+
session_id,
|
| 156 |
+
timing_summary,
|
| 157 |
error_step=error_step,
|
| 158 |
+
notes=f"Pipeline error: {str(e)}",
|
| 159 |
)
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
return "I apologize, but I encountered an error processing your request. Please try again."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
# === Main Interface Logic ===
|
| 165 |
|
| 166 |
if __name__ == "__main__":
|
| 167 |
+
iface = create_interface(intent_classifier, retriever)
|
| 168 |
iface.launch(
|
| 169 |
+
share=False, server_name=SERVER_NAME, server_port=SERVER_PORT, ssr_mode=False
|
| 170 |
+
)
|
|
|
|
|
|
|
|
|
docker-compose.yml
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
services:
|
| 2 |
xeno-bot:
|
|
|
|
| 3 |
build:
|
| 4 |
context: .
|
| 5 |
dockerfile: Dockerfile
|
|
|
|
| 1 |
services:
|
| 2 |
xeno-bot:
|
| 3 |
+
image: rogerzmukiibi/xeno-bot:test_v1
|
| 4 |
build:
|
| 5 |
context: .
|
| 6 |
dockerfile: Dockerfile
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@ huggingface_hub==0.25.2
|
|
| 2 |
gradio
|
| 3 |
pydantic==2.10.6
|
| 4 |
pandas
|
| 5 |
-
torch=
|
| 6 |
numpy
|
| 7 |
sentence-transformers
|
| 8 |
google-genai
|
|
|
|
| 2 |
gradio
|
| 3 |
pydantic==2.10.6
|
| 4 |
pandas
|
| 5 |
+
torch>=2.3.1
|
| 6 |
numpy
|
| 7 |
sentence-transformers
|
| 8 |
google-genai
|
src/config.py
CHANGED
|
@@ -2,7 +2,9 @@
|
|
| 2 |
Configuration module for XENO Bot
|
| 3 |
Handles environment variables and application settings
|
| 4 |
"""
|
|
|
|
| 5 |
import os
|
|
|
|
| 6 |
from google import genai
|
| 7 |
|
| 8 |
# === API Configuration ===
|
|
@@ -11,7 +13,7 @@ if not GEMINI_API_KEY:
|
|
| 11 |
raise ValueError("GEMINI_API_KEY environment variable not set.")
|
| 12 |
|
| 13 |
# Initialize the genai client
|
| 14 |
-
|
| 15 |
|
| 16 |
# === Model Configuration ===
|
| 17 |
EMBEDDING_MODEL = "text-embedding-004"
|
|
@@ -30,6 +32,7 @@ GOOGLE_SHEETS_CREDENTIALS_ENV = "GOOGLE_SHEETS_CREDENTIALS"
|
|
| 30 |
SPREADSHEET_NAME = "Response_Log"
|
| 31 |
RESPONSE_SHEET_INDEX = 0 # sheet1
|
| 32 |
TIMING_SHEET_NAME = "Timing_Log"
|
|
|
|
| 33 |
|
| 34 |
# === RAG Configuration ===
|
| 35 |
RAG_TOP_K = 4
|
|
|
|
| 2 |
Configuration module for XENO Bot
|
| 3 |
Handles environment variables and application settings
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import os
|
| 7 |
+
|
| 8 |
from google import genai
|
| 9 |
|
| 10 |
# === API Configuration ===
|
|
|
|
| 13 |
raise ValueError("GEMINI_API_KEY environment variable not set.")
|
| 14 |
|
| 15 |
# Initialize the genai client
|
| 16 |
+
genai_client = genai.Client(api_key=GEMINI_API_KEY)
|
| 17 |
|
| 18 |
# === Model Configuration ===
|
| 19 |
EMBEDDING_MODEL = "text-embedding-004"
|
|
|
|
| 32 |
SPREADSHEET_NAME = "Response_Log"
|
| 33 |
RESPONSE_SHEET_INDEX = 0 # sheet1
|
| 34 |
TIMING_SHEET_NAME = "Timing_Log"
|
| 35 |
+
FEEDBACK_SHEET_NAME = "Feedback_Log"
|
| 36 |
|
| 37 |
# === RAG Configuration ===
|
| 38 |
RAG_TOP_K = 4
|
src/intent_classifier.py
CHANGED
|
@@ -2,62 +2,63 @@
|
|
| 2 |
Intent Classification module for XENO Bot
|
| 3 |
Handles classification of user intents (greetings, thanks, goodbye, queries)
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
import random
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class IntentClassifier:
|
| 11 |
"""Classifies user intents and provides appropriate responses"""
|
| 12 |
-
|
| 13 |
def __init__(self):
|
| 14 |
self.intent_patterns = {
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
r
|
| 18 |
-
r
|
| 19 |
-
r
|
| 20 |
],
|
| 21 |
-
|
| 22 |
"Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
|
| 23 |
"Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
|
| 24 |
-
"Good day! Welcome to XENO Support. How may I assist you today?"
|
| 25 |
-
]
|
| 26 |
},
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
r
|
| 30 |
-
r
|
| 31 |
-
r
|
| 32 |
],
|
| 33 |
-
|
| 34 |
"You're welcome! Is there anything else I can help you with regarding XENO services?",
|
| 35 |
"Happy to help! Feel free to ask if you have any other questions about XENO.",
|
| 36 |
-
"Glad I could assist you! Let me know if you need help with anything else."
|
| 37 |
-
]
|
| 38 |
},
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
r
|
| 42 |
-
r
|
| 43 |
-
r
|
| 44 |
],
|
| 45 |
-
|
| 46 |
"Goodbye! Thank you for using XENO services. Have a great day!",
|
| 47 |
"Take care! Feel free to return anytime you need help with XENO services.",
|
| 48 |
-
"Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO."
|
| 49 |
-
]
|
| 50 |
-
}
|
| 51 |
}
|
| 52 |
-
|
| 53 |
def classify_intent(self, message: str, timer=None) -> Tuple[str, str]:
|
| 54 |
"""
|
| 55 |
Classify the intent of a user message
|
| 56 |
-
|
| 57 |
Args:
|
| 58 |
message: User's message
|
| 59 |
timer: Optional timer object for tracking
|
| 60 |
-
|
| 61 |
Returns:
|
| 62 |
Tuple of (intent_name, response_text)
|
| 63 |
"""
|
|
@@ -66,42 +67,42 @@ class IntentClassifier:
|
|
| 66 |
return self._classify_intent_impl(message)
|
| 67 |
else:
|
| 68 |
return self._classify_intent_impl(message)
|
| 69 |
-
|
| 70 |
def _classify_intent_impl(self, message: str) -> Tuple[str, str]:
|
| 71 |
"""Internal implementation of intent classification"""
|
| 72 |
message_lower = message.lower().strip()
|
| 73 |
-
|
| 74 |
for intent_name, intent_data in self.intent_patterns.items():
|
| 75 |
-
for pattern in intent_data[
|
| 76 |
if re.search(pattern, message_lower, re.IGNORECASE):
|
| 77 |
-
response = random.choice(intent_data[
|
| 78 |
return intent_name, response
|
| 79 |
-
|
| 80 |
-
return
|
| 81 |
-
|
| 82 |
def is_simple_intent(self, intent: str) -> bool:
|
| 83 |
"""
|
| 84 |
Check if the intent is a simple one that doesn't require RAG
|
| 85 |
-
|
| 86 |
Args:
|
| 87 |
intent: Intent name
|
| 88 |
-
|
| 89 |
Returns:
|
| 90 |
True if simple intent, False otherwise
|
| 91 |
"""
|
| 92 |
-
simple_intents = [
|
| 93 |
return intent in simple_intents
|
| 94 |
-
|
| 95 |
def add_intent(self, intent_name: str, patterns: List[str], responses: List[str]):
|
| 96 |
"""
|
| 97 |
Add a new intent to the classifier
|
| 98 |
-
|
| 99 |
Args:
|
| 100 |
intent_name: Name of the intent
|
| 101 |
patterns: List of regex patterns to match
|
| 102 |
responses: List of possible responses
|
| 103 |
"""
|
| 104 |
self.intent_patterns[intent_name] = {
|
| 105 |
-
|
| 106 |
-
|
| 107 |
}
|
|
|
|
| 2 |
Intent Classification module for XENO Bot
|
| 3 |
Handles classification of user intents (greetings, thanks, goodbye, queries)
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import random
|
| 7 |
+
import re
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
|
| 10 |
|
| 11 |
class IntentClassifier:
|
| 12 |
"""Classifies user intents and provides appropriate responses"""
|
| 13 |
+
|
| 14 |
def __init__(self):
|
| 15 |
self.intent_patterns = {
|
| 16 |
+
"greeting": {
|
| 17 |
+
"patterns": [
|
| 18 |
+
r"\b(hi|hello|hey|good morning|good afternoon|good evening|greetings)\b",
|
| 19 |
+
r"^(hi|hello|hey)[\s!.]*$",
|
| 20 |
+
r"\b(how are you|how do you do)\b",
|
| 21 |
],
|
| 22 |
+
"responses": [
|
| 23 |
"Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
|
| 24 |
"Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
|
| 25 |
+
"Good day! Welcome to XENO Support. How may I assist you today?",
|
| 26 |
+
],
|
| 27 |
},
|
| 28 |
+
"thanks": {
|
| 29 |
+
"patterns": [
|
| 30 |
+
r"\b(thank you|thanks|thank u|thx|appreciate|grateful)\b",
|
| 31 |
+
r"^(thanks|thank you)[\s!.]*$",
|
| 32 |
+
r"\b(much appreciated|thanks a lot|thank you so much)\b",
|
| 33 |
],
|
| 34 |
+
"responses": [
|
| 35 |
"You're welcome! Is there anything else I can help you with regarding XENO services?",
|
| 36 |
"Happy to help! Feel free to ask if you have any other questions about XENO.",
|
| 37 |
+
"Glad I could assist you! Let me know if you need help with anything else.",
|
| 38 |
+
],
|
| 39 |
},
|
| 40 |
+
"goodbye": {
|
| 41 |
+
"patterns": [
|
| 42 |
+
r"\b(bye|goodbye|see you|farewell|take care|have a good day)\b",
|
| 43 |
+
r"^(bye|goodbye)[\s!.]*$",
|
| 44 |
+
r"\b(talk to you later|see you later|until next time)\b",
|
| 45 |
],
|
| 46 |
+
"responses": [
|
| 47 |
"Goodbye! Thank you for using XENO services. Have a great day!",
|
| 48 |
"Take care! Feel free to return anytime you need help with XENO services.",
|
| 49 |
+
"Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO.",
|
| 50 |
+
],
|
| 51 |
+
},
|
| 52 |
}
|
| 53 |
+
|
| 54 |
def classify_intent(self, message: str, timer=None) -> Tuple[str, str]:
|
| 55 |
"""
|
| 56 |
Classify the intent of a user message
|
| 57 |
+
|
| 58 |
Args:
|
| 59 |
message: User's message
|
| 60 |
timer: Optional timer object for tracking
|
| 61 |
+
|
| 62 |
Returns:
|
| 63 |
Tuple of (intent_name, response_text)
|
| 64 |
"""
|
|
|
|
| 67 |
return self._classify_intent_impl(message)
|
| 68 |
else:
|
| 69 |
return self._classify_intent_impl(message)
|
| 70 |
+
|
| 71 |
def _classify_intent_impl(self, message: str) -> Tuple[str, str]:
|
| 72 |
"""Internal implementation of intent classification"""
|
| 73 |
message_lower = message.lower().strip()
|
| 74 |
+
|
| 75 |
for intent_name, intent_data in self.intent_patterns.items():
|
| 76 |
+
for pattern in intent_data["patterns"]:
|
| 77 |
if re.search(pattern, message_lower, re.IGNORECASE):
|
| 78 |
+
response = random.choice(intent_data["responses"])
|
| 79 |
return intent_name, response
|
| 80 |
+
|
| 81 |
+
return "query", ""
|
| 82 |
+
|
| 83 |
def is_simple_intent(self, intent: str) -> bool:
|
| 84 |
"""
|
| 85 |
Check if the intent is a simple one that doesn't require RAG
|
| 86 |
+
|
| 87 |
Args:
|
| 88 |
intent: Intent name
|
| 89 |
+
|
| 90 |
Returns:
|
| 91 |
True if simple intent, False otherwise
|
| 92 |
"""
|
| 93 |
+
simple_intents = ["greeting", "thanks"]
|
| 94 |
return intent in simple_intents
|
| 95 |
+
|
| 96 |
def add_intent(self, intent_name: str, patterns: List[str], responses: List[str]):
|
| 97 |
"""
|
| 98 |
Add a new intent to the classifier
|
| 99 |
+
|
| 100 |
Args:
|
| 101 |
intent_name: Name of the intent
|
| 102 |
patterns: List of regex patterns to match
|
| 103 |
responses: List of possible responses
|
| 104 |
"""
|
| 105 |
self.intent_patterns[intent_name] = {
|
| 106 |
+
"patterns": patterns,
|
| 107 |
+
"responses": responses,
|
| 108 |
}
|
src/interface.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from src.logger import log_feedback
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def respond(
|
| 10 |
+
message: str, history: List, session_id: str, intent_classifier, retriever
|
| 11 |
+
):
|
| 12 |
+
"""
|
| 13 |
+
Gradio's main response function
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
message: User's message
|
| 17 |
+
history: Chat history
|
| 18 |
+
session_id: Session identifier
|
| 19 |
+
intent_classifier: IntentClassifier instance
|
| 20 |
+
retriever: Vector store retriever instance
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Tuple of (empty string for input box, updated history)
|
| 24 |
+
"""
|
| 25 |
+
# Import here to avoid circular imports
|
| 26 |
+
from app import get_context_and_answer
|
| 27 |
+
|
| 28 |
+
if not session_id:
|
| 29 |
+
session_id = str(uuid.uuid4())
|
| 30 |
+
|
| 31 |
+
bot_response = get_context_and_answer(
|
| 32 |
+
message, history, session_id, intent_classifier, retriever
|
| 33 |
+
)
|
| 34 |
+
history.append([message, bot_response])
|
| 35 |
+
|
| 36 |
+
return "", history
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_interface(intent_classifier, retriever):
|
| 40 |
+
"""
|
| 41 |
+
Create Gradio interface
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
intent_classifier: IntentClassifier instance
|
| 45 |
+
retriever: Vector store retriever instance
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Gradio Blocks interface
|
| 49 |
+
"""
|
| 50 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 51 |
+
gr.Markdown("""
|
| 52 |
+
# ASKXENO
|
| 53 |
+
**Welcome to XENO AI Support!**
|
| 54 |
+
|
| 55 |
+
I can help you with questions about XENO financial services including:
|
| 56 |
+
- Account management and setup
|
| 57 |
+
- Transaction processes and fees
|
| 58 |
+
- Platform features and troubleshooting
|
| 59 |
+
- General service information
|
| 60 |
+
|
| 61 |
+
*Simply type your question below to get started!*
|
| 62 |
+
""")
|
| 63 |
+
|
| 64 |
+
# Hidden state for session
|
| 65 |
+
session_id_box = gr.Textbox(
|
| 66 |
+
label="Session ID", value=str(uuid.uuid4()), visible=False
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
chatbot = gr.Chatbot(
|
| 70 |
+
label="XENO Assistant", bubble_full_width=False, height=450
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
with gr.Row():
|
| 74 |
+
msg = gr.Textbox(
|
| 75 |
+
label="Your Message",
|
| 76 |
+
placeholder="Type your question here...",
|
| 77 |
+
scale=4,
|
| 78 |
+
)
|
| 79 |
+
send_button = gr.Button("Send", variant="primary", scale=1)
|
| 80 |
+
|
| 81 |
+
# ===== FEEDBACK SECTION =====
|
| 82 |
+
with gr.Row():
|
| 83 |
+
with gr.Accordion("Rate this response / Flag Issue", open=False):
|
| 84 |
+
with gr.Row():
|
| 85 |
+
thumbs_up = gr.Button("👍 Good Answer")
|
| 86 |
+
thumbs_down = gr.Button("👎 Bad / Flag")
|
| 87 |
+
|
| 88 |
+
feedback_reason = gr.Textbox(
|
| 89 |
+
label="Reason ", placeholder="E.g., Incorrect fees, hallucination,"
|
| 90 |
+
)
|
| 91 |
+
feedback_status = gr.Label(value="", label="Status", show_label=False)
|
| 92 |
+
|
| 93 |
+
# Feedback Event Listeners
|
| 94 |
+
# Logic: If Thumbs Up is clicked, send 'Positive'. If Textbox is empty, reason defaults to "Good".
|
| 95 |
+
thumbs_up.click(
|
| 96 |
+
fn=lambda h, s, r: log_feedback("Positive", r if r else "Good", h, s),
|
| 97 |
+
inputs=[chatbot, session_id_box, feedback_reason],
|
| 98 |
+
outputs=[feedback_status],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Logic: If Thumbs Down is clicked, send 'Negative' with the content of the textbox.
|
| 102 |
+
thumbs_down.click(
|
| 103 |
+
fn=lambda r, h, s: log_feedback("Negative", r, h, s),
|
| 104 |
+
inputs=[feedback_reason, chatbot, session_id_box],
|
| 105 |
+
outputs=[feedback_status],
|
| 106 |
+
)
|
| 107 |
+
# =============================
|
| 108 |
+
|
| 109 |
+
# Chat Event Listeners - Pass components to respond function
|
| 110 |
+
send_button.click(
|
| 111 |
+
lambda msg, chat, sid: respond(msg, chat, sid, intent_classifier, retriever),
|
| 112 |
+
[msg, chatbot, session_id_box],
|
| 113 |
+
[msg, chatbot],
|
| 114 |
+
)
|
| 115 |
+
msg.submit(
|
| 116 |
+
lambda msg, chat, sid: respond(msg, chat, sid, intent_classifier, retriever),
|
| 117 |
+
[msg, chatbot, session_id_box],
|
| 118 |
+
[msg, chatbot],
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return demo
|
src/knowledge_base.py
CHANGED
|
@@ -2,68 +2,80 @@
|
|
| 2 |
Knowledge Base module for XENO Bot
|
| 3 |
Handles loading and preparing knowledge base data
|
| 4 |
"""
|
|
|
|
|
|
|
|
|
|
| 5 |
import pandas as pd
|
| 6 |
-
|
| 7 |
from src.config import KNOWLEDGE_BASE_PATH
|
| 8 |
|
| 9 |
|
| 10 |
def load_knowledge_base(filepath: str = KNOWLEDGE_BASE_PATH) -> pd.DataFrame:
|
| 11 |
"""
|
| 12 |
Load knowledge base from JSON file
|
| 13 |
-
|
| 14 |
Args:
|
| 15 |
filepath: Path to the knowledge base JSON file
|
| 16 |
-
|
| 17 |
Returns:
|
| 18 |
DataFrame with knowledge base data
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
return df
|
| 23 |
|
| 24 |
|
| 25 |
-
def prepare_documents(
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
Prepare documents for vector store
|
| 28 |
-
|
| 29 |
Args:
|
| 30 |
data: List of knowledge base entries
|
| 31 |
-
|
| 32 |
Returns:
|
| 33 |
Tuple of (documents, metadatas, ids)
|
| 34 |
"""
|
| 35 |
documents, metadatas, ids = [], [], []
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
return documents, metadatas, ids
|
| 58 |
|
| 59 |
|
| 60 |
def get_knowledge_base_data() -> Tuple[List[str], List[Dict], List[str]]:
|
| 61 |
"""
|
| 62 |
Load and prepare knowledge base data
|
| 63 |
-
|
| 64 |
Returns:
|
| 65 |
Tuple of (documents, metadatas, ids)
|
| 66 |
"""
|
| 67 |
df = load_knowledge_base()
|
| 68 |
-
data_list = df.to_dict(
|
| 69 |
return prepare_documents(data_list)
|
|
|
|
| 2 |
Knowledge Base module for XENO Bot
|
| 3 |
Handles loading and preparing knowledge base data
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict, Hashable, List, Tuple
|
| 7 |
+
|
| 8 |
import pandas as pd
|
| 9 |
+
|
| 10 |
from src.config import KNOWLEDGE_BASE_PATH
|
| 11 |
|
| 12 |
|
| 13 |
def load_knowledge_base(filepath: str = KNOWLEDGE_BASE_PATH) -> pd.DataFrame:
|
| 14 |
"""
|
| 15 |
Load knowledge base from JSON file
|
| 16 |
+
|
| 17 |
Args:
|
| 18 |
filepath: Path to the knowledge base JSON file
|
| 19 |
+
|
| 20 |
Returns:
|
| 21 |
DataFrame with knowledge base data
|
| 22 |
"""
|
| 23 |
+
try:
|
| 24 |
+
df = pd.read_json(filepath)
|
| 25 |
+
df.dropna(subset=["Content"], inplace=True)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f"Error loading knowledge base: {e}")
|
| 28 |
+
df = pd.DataFrame()
|
| 29 |
return df
|
| 30 |
|
| 31 |
|
| 32 |
+
def prepare_documents(
|
| 33 |
+
data: List[Dict[Hashable, Any]],
|
| 34 |
+
) -> Tuple[List[str], List[Dict], List[str]]:
|
| 35 |
"""
|
| 36 |
Prepare documents for vector store
|
| 37 |
+
|
| 38 |
Args:
|
| 39 |
data: List of knowledge base entries
|
| 40 |
+
|
| 41 |
Returns:
|
| 42 |
Tuple of (documents, metadatas, ids)
|
| 43 |
"""
|
| 44 |
documents, metadatas, ids = [], [], []
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
for item in data:
|
| 48 |
+
# Create document text with question and answer
|
| 49 |
+
document_text = f"Question: {item['Question']}\nAnswer: {item['Content']}"
|
| 50 |
+
documents.append(document_text)
|
| 51 |
+
|
| 52 |
+
# Create metadata
|
| 53 |
+
metadata = {
|
| 54 |
+
"question": item["Question"],
|
| 55 |
+
"content": item["Content"],
|
| 56 |
+
"section": item.get("Section", ""),
|
| 57 |
+
"source": item.get("Source", ""),
|
| 58 |
+
"owner": item.get("Owner", ""),
|
| 59 |
+
"tag": item.get("Tag", ""),
|
| 60 |
+
"id": item["ID"],
|
| 61 |
+
}
|
| 62 |
+
metadatas.append(metadata)
|
| 63 |
+
|
| 64 |
+
# Add ID
|
| 65 |
+
ids.append(item["ID"])
|
| 66 |
+
except KeyError as e:
|
| 67 |
+
print(f"Missing expected key in data item: {e}")
|
| 68 |
+
|
| 69 |
return documents, metadatas, ids
|
| 70 |
|
| 71 |
|
| 72 |
def get_knowledge_base_data() -> Tuple[List[str], List[Dict], List[str]]:
|
| 73 |
"""
|
| 74 |
Load and prepare knowledge base data
|
| 75 |
+
|
| 76 |
Returns:
|
| 77 |
Tuple of (documents, metadatas, ids)
|
| 78 |
"""
|
| 79 |
df = load_knowledge_base()
|
| 80 |
+
data_list = df.to_dict("records")
|
| 81 |
return prepare_documents(data_list)
|
src/logger.py
CHANGED
|
@@ -2,81 +2,145 @@
|
|
| 2 |
Logging module for XENO Bot
|
| 3 |
Handles Google Sheets logging for responses and timing data
|
| 4 |
"""
|
|
|
|
| 5 |
import json
|
| 6 |
import os
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
-
from typing import
|
|
|
|
| 9 |
import gspread
|
| 10 |
from google.oauth2.service_account import Credentials
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
TIMING_SHEET_NAME
|
| 16 |
-
)
|
| 17 |
|
| 18 |
|
| 19 |
def get_google_sheets_credentials() -> Credentials:
|
| 20 |
"""
|
| 21 |
Get Google Sheets credentials from environment variable
|
| 22 |
-
|
| 23 |
Returns:
|
| 24 |
Google Sheets credentials object
|
| 25 |
"""
|
| 26 |
credentials_json = os.environ.get(GOOGLE_SHEETS_CREDENTIALS_ENV)
|
| 27 |
if not credentials_json:
|
| 28 |
-
raise ValueError(
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
credentials_dict = json.loads(credentials_json)
|
| 31 |
scope = [
|
| 32 |
-
"https://spreadsheets.google.com/feeds",
|
| 33 |
-
"https://www.googleapis.com/auth/drive"
|
| 34 |
]
|
| 35 |
creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
|
| 36 |
-
|
| 37 |
return creds
|
| 38 |
|
| 39 |
|
| 40 |
def initialize_sheets():
|
| 41 |
"""
|
| 42 |
Initialize Google Sheets client and get sheets
|
| 43 |
-
|
| 44 |
Returns:
|
| 45 |
Tuple of (response_sheet, timing_sheet)
|
| 46 |
"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# Get or create timing sheet
|
| 54 |
try:
|
| 55 |
timing_sheet = spreadsheet.worksheet(TIMING_SHEET_NAME)
|
| 56 |
except:
|
| 57 |
# Create timing sheet if it doesn't exist
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
# Initialize sheets
|
| 72 |
-
response_sheet, timing_sheet = initialize_sheets()
|
| 73 |
|
| 74 |
|
| 75 |
-
def log_response(
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
"""
|
| 78 |
Log response to Google Sheets
|
| 79 |
-
|
| 80 |
Args:
|
| 81 |
question: User's question
|
| 82 |
answer: Generated answer
|
|
@@ -87,28 +151,41 @@ def log_response(question: str, answer: str, source_ids: str,
|
|
| 87 |
"""
|
| 88 |
if timer:
|
| 89 |
with timer.time_step("response_logging"):
|
| 90 |
-
_log_response_impl(
|
|
|
|
|
|
|
| 91 |
else:
|
| 92 |
_log_response_impl(question, answer, source_ids, knowledge_pairs, session_id)
|
| 93 |
|
| 94 |
|
| 95 |
-
def _log_response_impl(
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"""Internal implementation of response logging"""
|
| 98 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 99 |
-
|
| 100 |
# Extract knowledge pairs
|
| 101 |
knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
|
| 102 |
knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
|
| 103 |
knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
|
| 104 |
knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
|
| 105 |
-
|
| 106 |
row = [
|
| 107 |
-
timestamp,
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
]
|
| 111 |
-
|
| 112 |
try:
|
| 113 |
response_sheet.append_row(row)
|
| 114 |
print(f"Logged response: {question} | Source IDs: {source_ids}")
|
|
@@ -116,14 +193,21 @@ def _log_response_impl(question: str, answer: str, source_ids: str,
|
|
| 116 |
print(f"Failed to log to Google Sheet: {e}")
|
| 117 |
# Fallback to local file
|
| 118 |
with open("/tmp/response_log.txt", "a") as f:
|
| 119 |
-
f.write(
|
|
|
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
-
def log_timing_data(
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
"""
|
| 125 |
Log timing data to Google Sheets
|
| 126 |
-
|
| 127 |
Args:
|
| 128 |
question: User's question
|
| 129 |
session_id: Session identifier
|
|
@@ -132,29 +216,29 @@ def log_timing_data(question: str, session_id: str, timing_summary: Dict,
|
|
| 132 |
notes: Additional notes
|
| 133 |
"""
|
| 134 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 135 |
-
step_times = timing_summary[
|
| 136 |
-
|
| 137 |
# Truncate long questions
|
| 138 |
truncated_question = question[:100] + "..." if len(question) > 100 else question
|
| 139 |
-
|
| 140 |
row = [
|
| 141 |
timestamp,
|
| 142 |
session_id,
|
| 143 |
truncated_question,
|
| 144 |
-
timing_summary[
|
| 145 |
-
step_times.get(
|
| 146 |
-
step_times.get(
|
| 147 |
-
step_times.get(
|
| 148 |
-
step_times.get(
|
| 149 |
-
step_times.get(
|
| 150 |
-
step_times.get(
|
| 151 |
-
step_times.get(
|
| 152 |
-
step_times.get(
|
| 153 |
-
step_times.get(
|
| 154 |
error_step or "",
|
| 155 |
-
notes or ""
|
| 156 |
]
|
| 157 |
-
|
| 158 |
try:
|
| 159 |
timing_sheet.append_row(row)
|
| 160 |
print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
|
|
@@ -163,3 +247,46 @@ def log_timing_data(question: str, session_id: str, timing_summary: Dict,
|
|
| 163 |
# Fallback to local file
|
| 164 |
with open("/tmp/timing_log.txt", "a") as f:
|
| 165 |
f.write(f"{timestamp},{session_id},{question},{timing_summary}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
Logging module for XENO Bot
|
| 3 |
Handles Google Sheets logging for responses and timing data
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import json
|
| 7 |
import os
|
| 8 |
+
import threading
|
| 9 |
from datetime import datetime
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
import gspread
|
| 13 |
from google.oauth2.service_account import Credentials
|
| 14 |
+
|
| 15 |
+
from src.config import (FEEDBACK_SHEET_NAME, GOOGLE_SHEETS_CREDENTIALS_ENV,
|
| 16 |
+
RESPONSE_SHEET_INDEX, SPREADSHEET_NAME,
|
| 17 |
+
TIMING_SHEET_NAME)
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def get_google_sheets_credentials() -> Credentials:
|
| 21 |
"""
|
| 22 |
Get Google Sheets credentials from environment variable
|
| 23 |
+
|
| 24 |
Returns:
|
| 25 |
Google Sheets credentials object
|
| 26 |
"""
|
| 27 |
credentials_json = os.environ.get(GOOGLE_SHEETS_CREDENTIALS_ENV)
|
| 28 |
if not credentials_json:
|
| 29 |
+
raise ValueError(
|
| 30 |
+
f"{GOOGLE_SHEETS_CREDENTIALS_ENV} environment variable not set."
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
credentials_dict = json.loads(credentials_json)
|
| 34 |
scope = [
|
| 35 |
+
"https://spreadsheets.google.com/feeds",
|
| 36 |
+
"https://www.googleapis.com/auth/drive",
|
| 37 |
]
|
| 38 |
creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
|
| 39 |
+
|
| 40 |
return creds
|
| 41 |
|
| 42 |
|
| 43 |
def initialize_sheets():
|
| 44 |
"""
|
| 45 |
Initialize Google Sheets client and get sheets
|
| 46 |
+
|
| 47 |
Returns:
|
| 48 |
Tuple of (response_sheet, timing_sheet)
|
| 49 |
"""
|
| 50 |
+
try:
|
| 51 |
+
client_gspread = gspread.authorize(get_google_sheets_credentials())
|
| 52 |
+
spreadsheet = client_gspread.open(SPREADSHEET_NAME)
|
| 53 |
+
|
| 54 |
+
# Get response sheet
|
| 55 |
+
response_sheet = spreadsheet.get_worksheet(RESPONSE_SHEET_INDEX)
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Failed to initialize Google Sheets: {e}")
|
| 58 |
+
|
| 59 |
+
# TODO Create dummy sheets or handle error appropriately
|
| 60 |
+
class DummySheet:
|
| 61 |
+
def append_row(self, *args, **kwargs):
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
def worksheet(self, *args):
|
| 65 |
+
return self
|
| 66 |
+
|
| 67 |
+
def add_worksheet(self, *args, **kwargs):
|
| 68 |
+
return self
|
| 69 |
+
|
| 70 |
+
spreadsheet = DummySheet()
|
| 71 |
+
response_sheet = DummySheet()
|
| 72 |
+
|
| 73 |
# Get or create timing sheet
|
| 74 |
try:
|
| 75 |
timing_sheet = spreadsheet.worksheet(TIMING_SHEET_NAME)
|
| 76 |
except:
|
| 77 |
# Create timing sheet if it doesn't exist
|
| 78 |
+
try:
|
| 79 |
+
timing_sheet = spreadsheet.add_worksheet(
|
| 80 |
+
title=TIMING_SHEET_NAME, rows=1000, cols=15
|
| 81 |
+
)
|
| 82 |
+
# Add headers
|
| 83 |
+
headers = [
|
| 84 |
+
"Timestamp",
|
| 85 |
+
"Session_ID",
|
| 86 |
+
"Question",
|
| 87 |
+
"Total_Time_MS",
|
| 88 |
+
"Intent_Classification_MS",
|
| 89 |
+
"Memory_Retrieval_MS",
|
| 90 |
+
"RAG_Retrieval_MS",
|
| 91 |
+
"Embedding_Generation_MS",
|
| 92 |
+
"Similarity_Calculation_MS",
|
| 93 |
+
"Context_Processing_MS",
|
| 94 |
+
"LLM_Generation_MS",
|
| 95 |
+
"Memory_Update_MS",
|
| 96 |
+
"Logging_MS",
|
| 97 |
+
"Error_Step",
|
| 98 |
+
"Notes",
|
| 99 |
+
]
|
| 100 |
+
timing_sheet.append_row(headers)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Failed to create timing sheet: {e}")
|
| 103 |
+
timing_sheet = DummySheet()
|
| 104 |
+
|
| 105 |
+
# Feedback Sheet
|
| 106 |
+
try:
|
| 107 |
+
feedback_sheet = spreadsheet.worksheet(FEEDBACK_SHEET_NAME)
|
| 108 |
+
except:
|
| 109 |
+
try:
|
| 110 |
+
feedback_sheet = spreadsheet.add_worksheet(
|
| 111 |
+
title=FEEDBACK_SHEET_NAME, rows=1000, cols=6
|
| 112 |
+
)
|
| 113 |
+
headers = [
|
| 114 |
+
"Timestamp",
|
| 115 |
+
"Session_ID",
|
| 116 |
+
"User_Message",
|
| 117 |
+
"Bot_Response",
|
| 118 |
+
"Rating",
|
| 119 |
+
"Flag_Reason",
|
| 120 |
+
]
|
| 121 |
+
feedback_sheet.append_row(headers)
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"Failed to create feedback sheet: {e}")
|
| 124 |
+
feedback_sheet = DummySheet()
|
| 125 |
+
|
| 126 |
+
return response_sheet, timing_sheet, feedback_sheet
|
| 127 |
|
| 128 |
|
| 129 |
# Initialize sheets
|
| 130 |
+
response_sheet, timing_sheet, feedback_sheet = initialize_sheets()
|
| 131 |
|
| 132 |
|
| 133 |
+
def log_response(
|
| 134 |
+
question: str,
|
| 135 |
+
answer: str,
|
| 136 |
+
source_ids: str,
|
| 137 |
+
knowledge_pairs: List[Tuple[str, str]],
|
| 138 |
+
session_id: str,
|
| 139 |
+
timer=None,
|
| 140 |
+
):
|
| 141 |
"""
|
| 142 |
Log response to Google Sheets
|
| 143 |
+
|
| 144 |
Args:
|
| 145 |
question: User's question
|
| 146 |
answer: Generated answer
|
|
|
|
| 151 |
"""
|
| 152 |
if timer:
|
| 153 |
with timer.time_step("response_logging"):
|
| 154 |
+
_log_response_impl(
|
| 155 |
+
question, answer, source_ids, knowledge_pairs, session_id
|
| 156 |
+
)
|
| 157 |
else:
|
| 158 |
_log_response_impl(question, answer, source_ids, knowledge_pairs, session_id)
|
| 159 |
|
| 160 |
|
| 161 |
+
def _log_response_impl(
|
| 162 |
+
question: str,
|
| 163 |
+
answer: str,
|
| 164 |
+
source_ids: str,
|
| 165 |
+
knowledge_pairs: List[Tuple[str, str]],
|
| 166 |
+
session_id: str,
|
| 167 |
+
):
|
| 168 |
"""Internal implementation of response logging"""
|
| 169 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 170 |
+
|
| 171 |
# Extract knowledge pairs
|
| 172 |
knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
|
| 173 |
knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
|
| 174 |
knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
|
| 175 |
knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
|
| 176 |
+
|
| 177 |
row = [
|
| 178 |
+
timestamp,
|
| 179 |
+
session_id,
|
| 180 |
+
question,
|
| 181 |
+
answer,
|
| 182 |
+
source_ids,
|
| 183 |
+
knowledge_question_1,
|
| 184 |
+
knowledge_answer_1,
|
| 185 |
+
knowledge_question_2,
|
| 186 |
+
knowledge_answer_2,
|
| 187 |
]
|
| 188 |
+
|
| 189 |
try:
|
| 190 |
response_sheet.append_row(row)
|
| 191 |
print(f"Logged response: {question} | Source IDs: {source_ids}")
|
|
|
|
| 193 |
print(f"Failed to log to Google Sheet: {e}")
|
| 194 |
# Fallback to local file
|
| 195 |
with open("/tmp/response_log.txt", "a") as f:
|
| 196 |
+
f.write(
|
| 197 |
+
f"{timestamp},{question},{answer},{source_ids},{knowledge_question_1},{knowledge_answer_1},{knowledge_question_2},{knowledge_answer_2}\n"
|
| 198 |
+
)
|
| 199 |
|
| 200 |
|
| 201 |
+
def log_timing_data(
|
| 202 |
+
question: str,
|
| 203 |
+
session_id: str,
|
| 204 |
+
timing_summary: Dict,
|
| 205 |
+
error_step: Optional[str] = None,
|
| 206 |
+
notes: Optional[str] = None,
|
| 207 |
+
):
|
| 208 |
"""
|
| 209 |
Log timing data to Google Sheets
|
| 210 |
+
|
| 211 |
Args:
|
| 212 |
question: User's question
|
| 213 |
session_id: Session identifier
|
|
|
|
| 216 |
notes: Additional notes
|
| 217 |
"""
|
| 218 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 219 |
+
step_times = timing_summary["step_times"]
|
| 220 |
+
|
| 221 |
# Truncate long questions
|
| 222 |
truncated_question = question[:100] + "..." if len(question) > 100 else question
|
| 223 |
+
|
| 224 |
row = [
|
| 225 |
timestamp,
|
| 226 |
session_id,
|
| 227 |
truncated_question,
|
| 228 |
+
timing_summary["total_time_ms"],
|
| 229 |
+
step_times.get("intent_classification", 0),
|
| 230 |
+
step_times.get("memory_retrieval", 0),
|
| 231 |
+
step_times.get("rag_retrieval", 0),
|
| 232 |
+
step_times.get("embedding_generation", 0),
|
| 233 |
+
step_times.get("similarity_calculation", 0),
|
| 234 |
+
step_times.get("context_processing", 0),
|
| 235 |
+
step_times.get("llm_generation", 0),
|
| 236 |
+
step_times.get("memory_update", 0),
|
| 237 |
+
step_times.get("response_logging", 0),
|
| 238 |
error_step or "",
|
| 239 |
+
notes or "",
|
| 240 |
]
|
| 241 |
+
|
| 242 |
try:
|
| 243 |
timing_sheet.append_row(row)
|
| 244 |
print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
|
|
|
|
| 247 |
# Fallback to local file
|
| 248 |
with open("/tmp/timing_log.txt", "a") as f:
|
| 249 |
f.write(f"{timestamp},{session_id},{question},{timing_summary}\n")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _log_feedback_background(row):
|
| 253 |
+
"""Helper to run network request in background thread"""
|
| 254 |
+
try:
|
| 255 |
+
if feedback_sheet:
|
| 256 |
+
feedback_sheet.append_row(row)
|
| 257 |
+
print("Feedback logged successfully.")
|
| 258 |
+
else:
|
| 259 |
+
print("Feedback sheet not available.")
|
| 260 |
+
except Exception as e:
|
| 261 |
+
print(f"Failed to log feedback: {e}")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def log_feedback(rating, reason, history, session_id):
|
| 265 |
+
"""
|
| 266 |
+
Handles user feedback submission.
|
| 267 |
+
rating: 'Positive' or 'Negative'
|
| 268 |
+
reason: User provided text
|
| 269 |
+
history: Gradio chat history list
|
| 270 |
+
"""
|
| 271 |
+
if not history or len(history) == 0:
|
| 272 |
+
return "No conversation to rate yet."
|
| 273 |
+
|
| 274 |
+
# Get the last interaction (Gradio history is a list of lists: [[user, bot], ...])
|
| 275 |
+
last_interaction = history[-1]
|
| 276 |
+
|
| 277 |
+
# Safety check for history format
|
| 278 |
+
if isinstance(last_interaction, list) and len(last_interaction) >= 2:
|
| 279 |
+
user_msg = last_interaction[0]
|
| 280 |
+
bot_msg = last_interaction[1]
|
| 281 |
+
else:
|
| 282 |
+
return "Error reading conversation history."
|
| 283 |
+
|
| 284 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 285 |
+
|
| 286 |
+
# Prepare row data
|
| 287 |
+
row = [timestamp, session_id, user_msg, bot_msg, rating, reason]
|
| 288 |
+
|
| 289 |
+
# Run in thread to prevent UI blocking
|
| 290 |
+
threading.Thread(target=_log_feedback_background, args=(row,)).start()
|
| 291 |
+
|
| 292 |
+
return f"Feedback received ({rating}). Thank you!"
|
src/memory.py
CHANGED
|
@@ -2,11 +2,14 @@
|
|
| 2 |
Memory module for XENO Bot
|
| 3 |
Handles LangGraph memory operations using SQLite
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
import sqlite3
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
-
from typing import
|
|
|
|
| 9 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
|
|
| 10 |
from src.config import SQLITE_DB_PATH
|
| 11 |
|
| 12 |
# === LangGraph Memory Setup ===
|
|
@@ -14,10 +17,12 @@ conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
|
|
| 14 |
memory = SqliteSaver(conn=conn)
|
| 15 |
|
| 16 |
|
| 17 |
-
def update_memory(
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
Update memory with new messages
|
| 20 |
-
|
| 21 |
Args:
|
| 22 |
config: Configuration dictionary with thread_id
|
| 23 |
user_message: User's message
|
|
@@ -31,34 +36,34 @@ def update_memory(config: Dict[str, Any], user_message: str, assistant_message:
|
|
| 31 |
_update_memory_impl(config, user_message, assistant_message)
|
| 32 |
|
| 33 |
|
| 34 |
-
def _update_memory_impl(config
|
| 35 |
"""Internal implementation of memory update"""
|
| 36 |
full_checkpoint = memory.get(config) or {}
|
| 37 |
messages = full_checkpoint.get("channel_values", {}).get("messages", [])
|
| 38 |
-
|
| 39 |
messages.append({"role": "user", "content": user_message})
|
| 40 |
messages.append({"role": "assistant", "content": assistant_message})
|
| 41 |
-
|
| 42 |
checkpoint_to_save = {
|
| 43 |
"v": 1,
|
| 44 |
"id": str(uuid.uuid4()),
|
| 45 |
"ts": datetime.now().isoformat(),
|
| 46 |
"channel_values": {"messages": messages},
|
| 47 |
"channel_versions": {},
|
| 48 |
-
"versions_seen": {},
|
| 49 |
}
|
| 50 |
-
|
| 51 |
memory.put(config, checkpoint_to_save, {}, {})
|
| 52 |
|
| 53 |
|
| 54 |
def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
|
| 55 |
"""
|
| 56 |
Retrieve memory messages for a session
|
| 57 |
-
|
| 58 |
Args:
|
| 59 |
config: Configuration dictionary with thread_id
|
| 60 |
timer: Optional timer object for tracking
|
| 61 |
-
|
| 62 |
Returns:
|
| 63 |
List of message dictionaries
|
| 64 |
"""
|
|
@@ -69,7 +74,7 @@ def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
|
|
| 69 |
return _retrieve_memory_impl(config)
|
| 70 |
|
| 71 |
|
| 72 |
-
def _retrieve_memory_impl(config
|
| 73 |
"""Internal implementation of memory retrieval"""
|
| 74 |
full_checkpoint = memory.get(config) or {}
|
| 75 |
return full_checkpoint.get("channel_values", {}).get("messages", [])
|
|
@@ -78,10 +83,10 @@ def _retrieve_memory_impl(config: Dict[str, Any]) -> List[Dict[str, str]]:
|
|
| 78 |
def create_session_config(session_id: str = "default") -> Dict[str, Any]:
|
| 79 |
"""
|
| 80 |
Create a configuration dictionary for a session
|
| 81 |
-
|
| 82 |
Args:
|
| 83 |
session_id: Unique session identifier
|
| 84 |
-
|
| 85 |
Returns:
|
| 86 |
Configuration dictionary
|
| 87 |
"""
|
|
|
|
| 2 |
Memory module for XENO Bot
|
| 3 |
Handles LangGraph memory operations using SQLite
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import sqlite3
|
| 7 |
+
import uuid
|
| 8 |
from datetime import datetime
|
| 9 |
+
from typing import Any, Dict, List
|
| 10 |
+
|
| 11 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 12 |
+
|
| 13 |
from src.config import SQLITE_DB_PATH
|
| 14 |
|
| 15 |
# === LangGraph Memory Setup ===
|
|
|
|
| 17 |
memory = SqliteSaver(conn=conn)
|
| 18 |
|
| 19 |
|
| 20 |
+
def update_memory(
|
| 21 |
+
config: Dict[str, Any], user_message: str, assistant_message: str, timer=None
|
| 22 |
+
):
|
| 23 |
"""
|
| 24 |
Update memory with new messages
|
| 25 |
+
|
| 26 |
Args:
|
| 27 |
config: Configuration dictionary with thread_id
|
| 28 |
user_message: User's message
|
|
|
|
| 36 |
_update_memory_impl(config, user_message, assistant_message)
|
| 37 |
|
| 38 |
|
| 39 |
+
def _update_memory_impl(config, user_message: str, assistant_message: str):
|
| 40 |
"""Internal implementation of memory update"""
|
| 41 |
full_checkpoint = memory.get(config) or {}
|
| 42 |
messages = full_checkpoint.get("channel_values", {}).get("messages", [])
|
| 43 |
+
|
| 44 |
messages.append({"role": "user", "content": user_message})
|
| 45 |
messages.append({"role": "assistant", "content": assistant_message})
|
| 46 |
+
|
| 47 |
checkpoint_to_save = {
|
| 48 |
"v": 1,
|
| 49 |
"id": str(uuid.uuid4()),
|
| 50 |
"ts": datetime.now().isoformat(),
|
| 51 |
"channel_values": {"messages": messages},
|
| 52 |
"channel_versions": {},
|
| 53 |
+
"versions_seen": {},
|
| 54 |
}
|
| 55 |
+
|
| 56 |
memory.put(config, checkpoint_to_save, {}, {})
|
| 57 |
|
| 58 |
|
| 59 |
def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
|
| 60 |
"""
|
| 61 |
Retrieve memory messages for a session
|
| 62 |
+
|
| 63 |
Args:
|
| 64 |
config: Configuration dictionary with thread_id
|
| 65 |
timer: Optional timer object for tracking
|
| 66 |
+
|
| 67 |
Returns:
|
| 68 |
List of message dictionaries
|
| 69 |
"""
|
|
|
|
| 74 |
return _retrieve_memory_impl(config)
|
| 75 |
|
| 76 |
|
| 77 |
+
def _retrieve_memory_impl(config) -> List[Dict[str, str]]:
|
| 78 |
"""Internal implementation of memory retrieval"""
|
| 79 |
full_checkpoint = memory.get(config) or {}
|
| 80 |
return full_checkpoint.get("channel_values", {}).get("messages", [])
|
|
|
|
| 83 |
def create_session_config(session_id: str = "default") -> Dict[str, Any]:
|
| 84 |
"""
|
| 85 |
Create a configuration dictionary for a session
|
| 86 |
+
|
| 87 |
Args:
|
| 88 |
session_id: Unique session identifier
|
| 89 |
+
|
| 90 |
Returns:
|
| 91 |
Configuration dictionary
|
| 92 |
"""
|
src/response_generator.py
CHANGED
|
@@ -2,21 +2,24 @@
|
|
| 2 |
Response Generation module for XENO Bot
|
| 3 |
Handles LLM response generation
|
| 4 |
"""
|
| 5 |
-
from google import genai
|
| 6 |
-
from typing import List, Dict
|
| 7 |
-
from src.config import LLM_MODEL_NAME, SYSTEM_PROMPT, client
|
| 8 |
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
Generate a response using the LLM
|
| 13 |
-
|
| 14 |
Args:
|
| 15 |
context: Formatted context from knowledge base
|
| 16 |
question: User's question
|
| 17 |
chat_history: List of previous messages
|
| 18 |
timer: Optional timer object for tracking
|
| 19 |
-
|
| 20 |
Returns:
|
| 21 |
Generated response text
|
| 22 |
"""
|
|
@@ -27,42 +30,47 @@ def generate_xeno_response(context: str, question: str, chat_history: List[Dict[
|
|
| 27 |
return _generate_response_impl(context, question, chat_history)
|
| 28 |
|
| 29 |
|
| 30 |
-
def _generate_response_impl(
|
|
|
|
|
|
|
| 31 |
"""Internal implementation of response generation"""
|
| 32 |
# Format chat history
|
| 33 |
-
formatted_history =
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Build prompt
|
| 38 |
prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
|
| 39 |
-
|
| 40 |
# Generate response
|
| 41 |
-
response =
|
| 42 |
-
model=LLM_MODEL_NAME,
|
| 43 |
-
contents={"text": prompt}
|
| 44 |
)
|
| 45 |
-
|
| 46 |
return response.text
|
| 47 |
|
| 48 |
|
| 49 |
def format_chat_history(messages: List[Dict[str, str]]) -> str:
|
| 50 |
"""
|
| 51 |
Format chat history for display or logging
|
| 52 |
-
|
| 53 |
Args:
|
| 54 |
messages: List of message dictionaries with 'role' and 'content'
|
| 55 |
-
|
| 56 |
Returns:
|
| 57 |
Formatted string representation of chat history
|
| 58 |
"""
|
| 59 |
if not messages:
|
| 60 |
return "No previous conversation"
|
| 61 |
-
|
| 62 |
formatted = []
|
| 63 |
for msg in messages:
|
| 64 |
-
role = msg.get(
|
| 65 |
-
content = msg.get(
|
| 66 |
formatted.append(f"{role}: {content}")
|
| 67 |
-
|
| 68 |
return "\n".join(formatted)
|
|
|
|
| 2 |
Response Generation module for XENO Bot
|
| 3 |
Handles LLM response generation
|
| 4 |
"""
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
from typing import Dict, List
|
| 7 |
|
| 8 |
+
from src.config import LLM_MODEL_NAME, SYSTEM_PROMPT, genai_client
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_xeno_response(
|
| 12 |
+
context: str, question: str, chat_history: List[Dict[str, str]], timer=None
|
| 13 |
+
) -> str:
|
| 14 |
"""
|
| 15 |
Generate a response using the LLM
|
| 16 |
+
|
| 17 |
Args:
|
| 18 |
context: Formatted context from knowledge base
|
| 19 |
question: User's question
|
| 20 |
chat_history: List of previous messages
|
| 21 |
timer: Optional timer object for tracking
|
| 22 |
+
|
| 23 |
Returns:
|
| 24 |
Generated response text
|
| 25 |
"""
|
|
|
|
| 30 |
return _generate_response_impl(context, question, chat_history)
|
| 31 |
|
| 32 |
|
| 33 |
+
def _generate_response_impl(
|
| 34 |
+
context: str, question: str, chat_history: List[Dict[str, str]]
|
| 35 |
+
) -> str:
|
| 36 |
"""Internal implementation of response generation"""
|
| 37 |
# Format chat history
|
| 38 |
+
formatted_history = (
|
| 39 |
+
"\n".join(
|
| 40 |
+
[f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history]
|
| 41 |
+
)
|
| 42 |
+
if chat_history
|
| 43 |
+
else "None"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
# Build prompt
|
| 47 |
prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
|
| 48 |
+
|
| 49 |
# Generate response
|
| 50 |
+
response = genai_client.models.generate_content(
|
| 51 |
+
model=LLM_MODEL_NAME, contents=prompt
|
|
|
|
| 52 |
)
|
| 53 |
+
|
| 54 |
return response.text
|
| 55 |
|
| 56 |
|
| 57 |
def format_chat_history(messages: List[Dict[str, str]]) -> str:
|
| 58 |
"""
|
| 59 |
Format chat history for display or logging
|
| 60 |
+
|
| 61 |
Args:
|
| 62 |
messages: List of message dictionaries with 'role' and 'content'
|
| 63 |
+
|
| 64 |
Returns:
|
| 65 |
Formatted string representation of chat history
|
| 66 |
"""
|
| 67 |
if not messages:
|
| 68 |
return "No previous conversation"
|
| 69 |
+
|
| 70 |
formatted = []
|
| 71 |
for msg in messages:
|
| 72 |
+
role = msg.get("role", "unknown").capitalize()
|
| 73 |
+
content = msg.get("content", "")
|
| 74 |
formatted.append(f"{role}: {content}")
|
| 75 |
+
|
| 76 |
return "\n".join(formatted)
|
src/utils.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
Utilities module for XENO Bot
|
| 3 |
Handles logging and timing functionality
|
| 4 |
"""
|
|
|
|
| 5 |
import logging
|
| 6 |
import sys
|
| 7 |
import time
|
|
@@ -13,14 +14,18 @@ from typing import Dict
|
|
| 13 |
logging.basicConfig(
|
| 14 |
filename="app.log",
|
| 15 |
level=logging.INFO,
|
| 16 |
-
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 17 |
)
|
| 18 |
|
|
|
|
| 19 |
def log_exception(exc_type, exc_value, exc_traceback):
|
| 20 |
"""Log uncaught exceptions"""
|
| 21 |
if issubclass(exc_type, KeyboardInterrupt):
|
| 22 |
return
|
| 23 |
-
logging.critical(
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
sys.excepthook = log_exception
|
| 26 |
logging.info("App started successfully.")
|
|
@@ -29,17 +34,17 @@ logging.info("App started successfully.")
|
|
| 29 |
# ===== Time Tracking Class =====
|
| 30 |
class PipelineTimer:
|
| 31 |
"""Timer for tracking pipeline execution steps"""
|
| 32 |
-
|
| 33 |
def __init__(self):
|
| 34 |
self.reset()
|
| 35 |
-
|
| 36 |
def reset(self):
|
| 37 |
"""Reset all timing data for a new request"""
|
| 38 |
self.start_time = time.time()
|
| 39 |
self.step_times = {}
|
| 40 |
self.step_start = None
|
| 41 |
self.current_step = None
|
| 42 |
-
|
| 43 |
@contextmanager
|
| 44 |
def time_step(self, step_name: str):
|
| 45 |
"""Context manager to time a specific step"""
|
|
@@ -49,18 +54,20 @@ class PipelineTimer:
|
|
| 49 |
yield
|
| 50 |
finally:
|
| 51 |
step_end = time.time()
|
| 52 |
-
self.step_times[step_name] = round(
|
|
|
|
|
|
|
| 53 |
self.current_step = None
|
| 54 |
-
|
| 55 |
def get_total_time(self):
|
| 56 |
"""Get total elapsed time since reset"""
|
| 57 |
return round((time.time() - self.start_time) * 1000, 2)
|
| 58 |
-
|
| 59 |
def get_timing_summary(self) -> Dict:
|
| 60 |
"""Get a summary of all timing data"""
|
| 61 |
total_time = self.get_total_time()
|
| 62 |
return {
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
}
|
|
|
|
| 2 |
Utilities module for XENO Bot
|
| 3 |
Handles logging and timing functionality
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import logging
|
| 7 |
import sys
|
| 8 |
import time
|
|
|
|
| 14 |
logging.basicConfig(
|
| 15 |
filename="app.log",
|
| 16 |
level=logging.INFO,
|
| 17 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 18 |
)
|
| 19 |
|
| 20 |
+
|
| 21 |
def log_exception(exc_type, exc_value, exc_traceback):
|
| 22 |
"""Log uncaught exceptions"""
|
| 23 |
if issubclass(exc_type, KeyboardInterrupt):
|
| 24 |
return
|
| 25 |
+
logging.critical(
|
| 26 |
+
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
|
| 30 |
sys.excepthook = log_exception
|
| 31 |
logging.info("App started successfully.")
|
|
|
|
| 34 |
# ===== Time Tracking Class =====
|
| 35 |
class PipelineTimer:
|
| 36 |
"""Timer for tracking pipeline execution steps"""
|
| 37 |
+
|
| 38 |
def __init__(self):
|
| 39 |
self.reset()
|
| 40 |
+
|
| 41 |
def reset(self):
|
| 42 |
"""Reset all timing data for a new request"""
|
| 43 |
self.start_time = time.time()
|
| 44 |
self.step_times = {}
|
| 45 |
self.step_start = None
|
| 46 |
self.current_step = None
|
| 47 |
+
|
| 48 |
@contextmanager
|
| 49 |
def time_step(self, step_name: str):
|
| 50 |
"""Context manager to time a specific step"""
|
|
|
|
| 54 |
yield
|
| 55 |
finally:
|
| 56 |
step_end = time.time()
|
| 57 |
+
self.step_times[step_name] = round(
|
| 58 |
+
(step_end - step_start) * 1000, 2
|
| 59 |
+
) # Convert to milliseconds
|
| 60 |
self.current_step = None
|
| 61 |
+
|
| 62 |
def get_total_time(self):
|
| 63 |
"""Get total elapsed time since reset"""
|
| 64 |
return round((time.time() - self.start_time) * 1000, 2)
|
| 65 |
+
|
| 66 |
def get_timing_summary(self) -> Dict:
|
| 67 |
"""Get a summary of all timing data"""
|
| 68 |
total_time = self.get_total_time()
|
| 69 |
return {
|
| 70 |
+
"total_time_ms": total_time,
|
| 71 |
+
"step_times": self.step_times,
|
| 72 |
+
"timestamp": datetime.now().isoformat(),
|
| 73 |
}
|
src/vector_store.py
CHANGED
|
@@ -2,38 +2,34 @@
|
|
| 2 |
Vector Store module for XENO Bot
|
| 3 |
Handles ChromaDB vector store operations
|
| 4 |
"""
|
|
|
|
|
|
|
|
|
|
| 5 |
import chromadb
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
from langchain_chroma import Chroma
|
| 9 |
from sentence_transformers import util
|
| 10 |
-
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
client,
|
| 14 |
-
COLLECTION_NAME,
|
| 15 |
-
CHROMA_DB_PATH,
|
| 16 |
-
RAG_TOP_K,
|
| 17 |
-
RAG_MAX_RESULTS,
|
| 18 |
-
EMBEDDING_MODEL
|
| 19 |
-
)
|
| 20 |
from src.knowledge_base import get_knowledge_base_data
|
| 21 |
|
| 22 |
|
| 23 |
def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
|
| 24 |
"""
|
| 25 |
Initialize ChromaDB vector store
|
| 26 |
-
|
| 27 |
Returns:
|
| 28 |
Tuple of (collection, vector_store, retriever)
|
| 29 |
"""
|
| 30 |
# Get knowledge base data
|
| 31 |
documents, metadatas, ids = get_knowledge_base_data()
|
| 32 |
-
|
| 33 |
# Initialize ChromaDB client
|
| 34 |
try:
|
| 35 |
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
| 36 |
-
|
| 37 |
# Try to get existing collection
|
| 38 |
try:
|
| 39 |
collection = client.get_collection(name=COLLECTION_NAME)
|
|
@@ -43,30 +39,31 @@ def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
|
|
| 43 |
print(f"Creating new ChromaDB collection: {COLLECTION_NAME}")
|
| 44 |
collection = client.create_collection(name=COLLECTION_NAME)
|
| 45 |
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
| 46 |
-
|
| 47 |
# Create vector store and retriever
|
| 48 |
vector_store = Chroma(client=client, collection_name=COLLECTION_NAME)
|
| 49 |
retriever = vector_store.as_retriever(
|
| 50 |
-
search_type="similarity",
|
| 51 |
-
search_kwargs={"k": RAG_TOP_K}
|
| 52 |
)
|
| 53 |
-
|
| 54 |
return collection, vector_store, retriever
|
| 55 |
-
|
| 56 |
except Exception as e:
|
| 57 |
print(f"Failed to initialize ChromaDB: {e}")
|
| 58 |
raise
|
| 59 |
|
| 60 |
|
| 61 |
-
def generate_embeddings(
|
|
|
|
|
|
|
| 62 |
"""
|
| 63 |
Generate embeddings for query and documents
|
| 64 |
-
|
| 65 |
Args:
|
| 66 |
query: User query
|
| 67 |
documents: List of retrieved documents
|
| 68 |
timer: Optional timer object for tracking
|
| 69 |
-
|
| 70 |
Returns:
|
| 71 |
Tuple of (query_embedding, doc_embeddings)
|
| 72 |
"""
|
|
@@ -77,38 +74,40 @@ def generate_embeddings(query: str, documents: List[Any], timer=None) -> Tuple[L
|
|
| 77 |
return _generate_embeddings_impl(query, documents)
|
| 78 |
|
| 79 |
|
| 80 |
-
def _generate_embeddings_impl(
|
|
|
|
|
|
|
| 81 |
"""Internal implementation of embedding generation"""
|
| 82 |
# 1. Update query embedding access
|
| 83 |
-
query_result =
|
| 84 |
-
model=EMBEDDING_MODEL,
|
| 85 |
-
contents=query
|
| 86 |
)
|
| 87 |
# The SDK returns an EmbedContentResponse object with an 'embeddings' attribute
|
| 88 |
-
query_embedding = query_result.embeddings[0].values
|
| 89 |
-
|
| 90 |
# 2. Update document embeddings access
|
| 91 |
doc_contents = [doc.page_content for doc in documents]
|
| 92 |
-
doc_results =
|
| 93 |
-
model=EMBEDDING_MODEL,
|
| 94 |
-
contents=doc_contents
|
| 95 |
)
|
| 96 |
-
|
| 97 |
# Map the list of embedding objects to a list of vector values
|
| 98 |
doc_embeddings = [e.values for e in doc_results.embeddings]
|
| 99 |
-
|
| 100 |
return query_embedding, doc_embeddings
|
| 101 |
|
| 102 |
|
| 103 |
-
def calculate_similarity(
|
|
|
|
|
|
|
| 104 |
"""
|
| 105 |
Calculate cosine similarity between query and documents
|
| 106 |
-
|
| 107 |
Args:
|
| 108 |
query_embedding: Query embedding vector
|
| 109 |
doc_embeddings: List of document embedding vectors
|
| 110 |
timer: Optional timer object for tracking
|
| 111 |
-
|
| 112 |
Returns:
|
| 113 |
List of cosine similarity scores
|
| 114 |
"""
|
|
@@ -119,27 +118,32 @@ def calculate_similarity(query_embedding: List[float], doc_embeddings: List[List
|
|
| 119 |
return _calculate_similarity_impl(query_embedding, doc_embeddings)
|
| 120 |
|
| 121 |
|
| 122 |
-
def _calculate_similarity_impl(
|
|
|
|
|
|
|
| 123 |
"""Internal implementation of similarity calculation"""
|
| 124 |
cosine_scores = util.cos_sim(
|
| 125 |
-
torch.tensor(query_embedding).float(),
|
| 126 |
-
torch.tensor(doc_embeddings).float()
|
| 127 |
)[0].tolist()
|
| 128 |
-
|
| 129 |
return cosine_scores
|
| 130 |
|
| 131 |
|
| 132 |
-
def process_context(
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
| 135 |
Process retrieved context and format for LLM
|
| 136 |
-
|
| 137 |
Args:
|
| 138 |
results: List of retrieved documents
|
| 139 |
cosine_scores: List of similarity scores
|
| 140 |
max_results: Maximum number of results to include
|
| 141 |
timer: Optional timer object for tracking
|
| 142 |
-
|
| 143 |
Returns:
|
| 144 |
Tuple of (formatted_context, source_ids, knowledge_pairs)
|
| 145 |
"""
|
|
@@ -150,28 +154,29 @@ def process_context(results: List[Any], cosine_scores: List[float],
|
|
| 150 |
return _process_context_impl(results, cosine_scores, max_results)
|
| 151 |
|
| 152 |
|
| 153 |
-
def _process_context_impl(
|
| 154 |
-
|
|
|
|
| 155 |
"""Internal implementation of context processing"""
|
| 156 |
sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
|
| 157 |
-
|
| 158 |
formatted_context = ""
|
| 159 |
source_ids = []
|
| 160 |
knowledge_pairs = []
|
| 161 |
-
|
| 162 |
for i, idx in enumerate(sorted_indices, 1):
|
| 163 |
result = results[idx]
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
question = result.metadata.get(
|
| 167 |
-
answer = result.metadata.get(
|
| 168 |
-
|
| 169 |
formatted_context += f"Knowledge Entry {i}:\n"
|
| 170 |
formatted_context += f"Q: {question}\n"
|
| 171 |
formatted_context += f"A: {answer}\n"
|
| 172 |
formatted_context += "-" * 40 + "\n"
|
| 173 |
-
|
| 174 |
-
source_ids.append(result.metadata.get(
|
| 175 |
knowledge_pairs.append((question, answer))
|
| 176 |
-
|
| 177 |
return formatted_context, source_ids, knowledge_pairs
|
|
|
|
| 2 |
Vector Store module for XENO Bot
|
| 3 |
Handles ChromaDB vector store operations
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
+
from typing import Any, List, Tuple
|
| 7 |
+
|
| 8 |
import chromadb
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
from langchain_chroma import Chroma
|
| 12 |
from sentence_transformers import util
|
| 13 |
+
|
| 14 |
+
from src.config import (CHROMA_DB_PATH, COLLECTION_NAME, EMBEDDING_MODEL,
|
| 15 |
+
RAG_MAX_RESULTS, RAG_TOP_K, genai_client)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from src.knowledge_base import get_knowledge_base_data
|
| 17 |
|
| 18 |
|
| 19 |
def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
|
| 20 |
"""
|
| 21 |
Initialize ChromaDB vector store
|
| 22 |
+
|
| 23 |
Returns:
|
| 24 |
Tuple of (collection, vector_store, retriever)
|
| 25 |
"""
|
| 26 |
# Get knowledge base data
|
| 27 |
documents, metadatas, ids = get_knowledge_base_data()
|
| 28 |
+
|
| 29 |
# Initialize ChromaDB client
|
| 30 |
try:
|
| 31 |
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
| 32 |
+
|
| 33 |
# Try to get existing collection
|
| 34 |
try:
|
| 35 |
collection = client.get_collection(name=COLLECTION_NAME)
|
|
|
|
| 39 |
print(f"Creating new ChromaDB collection: {COLLECTION_NAME}")
|
| 40 |
collection = client.create_collection(name=COLLECTION_NAME)
|
| 41 |
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
| 42 |
+
|
| 43 |
# Create vector store and retriever
|
| 44 |
vector_store = Chroma(client=client, collection_name=COLLECTION_NAME)
|
| 45 |
retriever = vector_store.as_retriever(
|
| 46 |
+
search_type="similarity", search_kwargs={"k": RAG_TOP_K}
|
|
|
|
| 47 |
)
|
| 48 |
+
|
| 49 |
return collection, vector_store, retriever
|
| 50 |
+
|
| 51 |
except Exception as e:
|
| 52 |
print(f"Failed to initialize ChromaDB: {e}")
|
| 53 |
raise
|
| 54 |
|
| 55 |
|
| 56 |
+
def generate_embeddings(
|
| 57 |
+
query: str, documents: List[Any], timer=None
|
| 58 |
+
) -> Tuple[List[float], List[List[float]]]:
|
| 59 |
"""
|
| 60 |
Generate embeddings for query and documents
|
| 61 |
+
|
| 62 |
Args:
|
| 63 |
query: User query
|
| 64 |
documents: List of retrieved documents
|
| 65 |
timer: Optional timer object for tracking
|
| 66 |
+
|
| 67 |
Returns:
|
| 68 |
Tuple of (query_embedding, doc_embeddings)
|
| 69 |
"""
|
|
|
|
| 74 |
return _generate_embeddings_impl(query, documents)
|
| 75 |
|
| 76 |
|
| 77 |
+
def _generate_embeddings_impl(
|
| 78 |
+
query: str, documents: List[Any]
|
| 79 |
+
) -> Tuple[List[float], List[List[float]]]:
|
| 80 |
"""Internal implementation of embedding generation"""
|
| 81 |
# 1. Update query embedding access
|
| 82 |
+
query_result = genai_client.models.embed_content(
|
| 83 |
+
model=EMBEDDING_MODEL, contents=query
|
|
|
|
| 84 |
)
|
| 85 |
# The SDK returns an EmbedContentResponse object with an 'embeddings' attribute
|
| 86 |
+
query_embedding = query_result.embeddings[0].values
|
| 87 |
+
|
| 88 |
# 2. Update document embeddings access
|
| 89 |
doc_contents = [doc.page_content for doc in documents]
|
| 90 |
+
doc_results = genai_client.models.embed_content(
|
| 91 |
+
model=EMBEDDING_MODEL, contents=doc_contents
|
|
|
|
| 92 |
)
|
| 93 |
+
|
| 94 |
# Map the list of embedding objects to a list of vector values
|
| 95 |
doc_embeddings = [e.values for e in doc_results.embeddings]
|
| 96 |
+
|
| 97 |
return query_embedding, doc_embeddings
|
| 98 |
|
| 99 |
|
| 100 |
+
def calculate_similarity(
|
| 101 |
+
query_embedding: List[float], doc_embeddings: List[List[float]], timer=None
|
| 102 |
+
) -> List[float]:
|
| 103 |
"""
|
| 104 |
Calculate cosine similarity between query and documents
|
| 105 |
+
|
| 106 |
Args:
|
| 107 |
query_embedding: Query embedding vector
|
| 108 |
doc_embeddings: List of document embedding vectors
|
| 109 |
timer: Optional timer object for tracking
|
| 110 |
+
|
| 111 |
Returns:
|
| 112 |
List of cosine similarity scores
|
| 113 |
"""
|
|
|
|
| 118 |
return _calculate_similarity_impl(query_embedding, doc_embeddings)
|
| 119 |
|
| 120 |
|
| 121 |
+
def _calculate_similarity_impl(
|
| 122 |
+
query_embedding: List[float], doc_embeddings: List[List[float]]
|
| 123 |
+
) -> List[float]:
|
| 124 |
"""Internal implementation of similarity calculation"""
|
| 125 |
cosine_scores = util.cos_sim(
|
| 126 |
+
torch.tensor(query_embedding).float(), torch.tensor(doc_embeddings).float()
|
|
|
|
| 127 |
)[0].tolist()
|
| 128 |
+
|
| 129 |
return cosine_scores
|
| 130 |
|
| 131 |
|
| 132 |
+
def process_context(
|
| 133 |
+
results: List[Any],
|
| 134 |
+
cosine_scores: List[float],
|
| 135 |
+
max_results: int = RAG_MAX_RESULTS,
|
| 136 |
+
timer=None,
|
| 137 |
+
) -> Tuple[str, List[str], List[Tuple[str, str]]]:
|
| 138 |
"""
|
| 139 |
Process retrieved context and format for LLM
|
| 140 |
+
|
| 141 |
Args:
|
| 142 |
results: List of retrieved documents
|
| 143 |
cosine_scores: List of similarity scores
|
| 144 |
max_results: Maximum number of results to include
|
| 145 |
timer: Optional timer object for tracking
|
| 146 |
+
|
| 147 |
Returns:
|
| 148 |
Tuple of (formatted_context, source_ids, knowledge_pairs)
|
| 149 |
"""
|
|
|
|
| 154 |
return _process_context_impl(results, cosine_scores, max_results)
|
| 155 |
|
| 156 |
|
| 157 |
+
def _process_context_impl(
|
| 158 |
+
results: List[Any], cosine_scores: List[float], max_results: int
|
| 159 |
+
) -> Tuple[str, List[str], List[Tuple[str, str]]]:
|
| 160 |
"""Internal implementation of context processing"""
|
| 161 |
sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
|
| 162 |
+
|
| 163 |
formatted_context = ""
|
| 164 |
source_ids = []
|
| 165 |
knowledge_pairs = []
|
| 166 |
+
|
| 167 |
for i, idx in enumerate(sorted_indices, 1):
|
| 168 |
result = results[idx]
|
| 169 |
+
cosine_scores[idx]
|
| 170 |
+
|
| 171 |
+
question = result.metadata.get("question", "N/A")
|
| 172 |
+
answer = result.metadata.get("content", "N/A")
|
| 173 |
+
|
| 174 |
formatted_context += f"Knowledge Entry {i}:\n"
|
| 175 |
formatted_context += f"Q: {question}\n"
|
| 176 |
formatted_context += f"A: {answer}\n"
|
| 177 |
formatted_context += "-" * 40 + "\n"
|
| 178 |
+
|
| 179 |
+
source_ids.append(result.metadata.get("id", "N/A"))
|
| 180 |
knowledge_pairs.append((question, answer))
|
| 181 |
+
|
| 182 |
return formatted_context, source_ids, knowledge_pairs
|
tests/conftest.py
CHANGED
|
@@ -2,16 +2,18 @@
|
|
| 2 |
Pytest configuration file
|
| 3 |
Sets up test environment and fixtures
|
| 4 |
"""
|
|
|
|
| 5 |
import os
|
| 6 |
import sys
|
|
|
|
|
|
|
| 7 |
import pytest
|
| 8 |
-
from unittest.mock import Mock, MagicMock, patch, PropertyMock
|
| 9 |
|
| 10 |
# Add src to path
|
| 11 |
-
sys.path.insert(0, os.path.join(os.path.dirname(__file__),
|
| 12 |
|
| 13 |
# Set mock environment variables before importing any modules
|
| 14 |
-
os.environ.setdefault(
|
| 15 |
|
| 16 |
# Mock Google Sheets credentials
|
| 17 |
mock_credentials = {
|
|
@@ -24,20 +26,23 @@ mock_credentials = {
|
|
| 24 |
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 25 |
"token_uri": "https://oauth2.googleapis.com/token",
|
| 26 |
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
| 27 |
-
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test"
|
| 28 |
}
|
| 29 |
import json
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
# Mock google.oauth2 and gspread modules before src.logger imports them
|
| 33 |
mock_credentials_class = MagicMock()
|
| 34 |
mock_creds_instance = MagicMock()
|
| 35 |
-
mock_credentials_class.from_service_account_info = Mock(
|
|
|
|
|
|
|
| 36 |
|
| 37 |
mock_oauth2 = MagicMock()
|
| 38 |
mock_oauth2.service_account.Credentials = mock_credentials_class
|
| 39 |
-
sys.modules[
|
| 40 |
-
sys.modules[
|
| 41 |
|
| 42 |
mock_gspread = MagicMock()
|
| 43 |
mock_spreadsheet = MagicMock()
|
|
@@ -49,14 +54,15 @@ mock_spreadsheet.add_worksheet = Mock(return_value=mock_worksheet)
|
|
| 49 |
mock_client = MagicMock()
|
| 50 |
mock_client.open = Mock(return_value=mock_spreadsheet)
|
| 51 |
mock_gspread.authorize = Mock(return_value=mock_client)
|
| 52 |
-
sys.modules[
|
| 53 |
|
| 54 |
|
| 55 |
@pytest.fixture(autouse=True)
|
| 56 |
def mock_google_sheets():
|
| 57 |
"""Mock Google Sheets to avoid actual connections during testing"""
|
| 58 |
-
with patch(
|
| 59 |
-
|
|
|
|
| 60 |
mock_response.append_row = Mock()
|
| 61 |
mock_timing.append_row = Mock()
|
| 62 |
yield mock_response, mock_timing
|
|
@@ -65,20 +71,16 @@ def mock_google_sheets():
|
|
| 65 |
@pytest.fixture
|
| 66 |
def mock_genai():
|
| 67 |
"""Mock Google Generative AI"""
|
| 68 |
-
with patch(
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
yield {
|
| 72 |
-
'configure': mock_config,
|
| 73 |
-
'model': mock_model,
|
| 74 |
-
'embed': mock_embed
|
| 75 |
-
}
|
| 76 |
|
| 77 |
|
| 78 |
@pytest.fixture
|
| 79 |
def mock_chromadb():
|
| 80 |
"""Mock ChromaDB client"""
|
| 81 |
-
with patch(
|
| 82 |
mock_collection = Mock()
|
| 83 |
mock_client.return_value.get_collection.return_value = mock_collection
|
| 84 |
yield mock_client
|
|
@@ -87,7 +89,7 @@ def mock_chromadb():
|
|
| 87 |
@pytest.fixture
|
| 88 |
def mock_sqlite():
|
| 89 |
"""Mock SQLite connections for memory"""
|
| 90 |
-
with patch(
|
| 91 |
mock_conn = Mock()
|
| 92 |
mock_connect.return_value = mock_conn
|
| 93 |
yield mock_conn
|
|
@@ -97,21 +99,42 @@ def mock_sqlite():
|
|
| 97 |
def sample_documents():
|
| 98 |
"""Provide sample documents for testing"""
|
| 99 |
doc1 = Mock()
|
| 100 |
-
doc1.page_content =
|
|
|
|
|
|
|
| 101 |
doc1.metadata = {
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
}
|
| 107 |
-
|
| 108 |
doc2 = Mock()
|
| 109 |
doc2.page_content = "Question: What are the fees?\nAnswer: 1% per transaction."
|
| 110 |
doc2.metadata = {
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
}
|
| 116 |
-
|
| 117 |
return [doc1, doc2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
Pytest configuration file
|
| 3 |
Sets up test environment and fixtures
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 9 |
+
|
| 10 |
import pytest
|
|
|
|
| 11 |
|
| 12 |
# Add src to path
|
| 13 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 14 |
|
| 15 |
# Set mock environment variables before importing any modules
|
| 16 |
+
os.environ.setdefault("GEMINI_API_KEY", "test-api-key-12345")
|
| 17 |
|
| 18 |
# Mock Google Sheets credentials
|
| 19 |
mock_credentials = {
|
|
|
|
| 26 |
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 27 |
"token_uri": "https://oauth2.googleapis.com/token",
|
| 28 |
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
| 29 |
+
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test",
|
| 30 |
}
|
| 31 |
import json
|
| 32 |
+
|
| 33 |
+
os.environ.setdefault("GOOGLE_SHEETS_CREDENTIALS", json.dumps(mock_credentials))
|
| 34 |
|
| 35 |
# Mock google.oauth2 and gspread modules before src.logger imports them
|
| 36 |
mock_credentials_class = MagicMock()
|
| 37 |
mock_creds_instance = MagicMock()
|
| 38 |
+
mock_credentials_class.from_service_account_info = Mock(
|
| 39 |
+
return_value=mock_creds_instance
|
| 40 |
+
)
|
| 41 |
|
| 42 |
mock_oauth2 = MagicMock()
|
| 43 |
mock_oauth2.service_account.Credentials = mock_credentials_class
|
| 44 |
+
sys.modules["google.oauth2"] = mock_oauth2
|
| 45 |
+
sys.modules["google.oauth2.service_account"] = mock_oauth2.service_account
|
| 46 |
|
| 47 |
mock_gspread = MagicMock()
|
| 48 |
mock_spreadsheet = MagicMock()
|
|
|
|
| 54 |
mock_client = MagicMock()
|
| 55 |
mock_client.open = Mock(return_value=mock_spreadsheet)
|
| 56 |
mock_gspread.authorize = Mock(return_value=mock_client)
|
| 57 |
+
sys.modules["gspread"] = mock_gspread
|
| 58 |
|
| 59 |
|
| 60 |
@pytest.fixture(autouse=True)
|
| 61 |
def mock_google_sheets():
|
| 62 |
"""Mock Google Sheets to avoid actual connections during testing"""
|
| 63 |
+
with patch("src.logger.response_sheet") as mock_response, patch(
|
| 64 |
+
"src.logger.timing_sheet"
|
| 65 |
+
) as mock_timing:
|
| 66 |
mock_response.append_row = Mock()
|
| 67 |
mock_timing.append_row = Mock()
|
| 68 |
yield mock_response, mock_timing
|
|
|
|
| 71 |
@pytest.fixture
|
| 72 |
def mock_genai():
|
| 73 |
"""Mock Google Generative AI"""
|
| 74 |
+
with patch("google.generativeai.configure") as mock_config, patch(
|
| 75 |
+
"google.generativeai.GenerativeModel"
|
| 76 |
+
) as mock_model, patch("google.generativeai.embed_content") as mock_embed:
|
| 77 |
+
yield {"configure": mock_config, "model": mock_model, "embed": mock_embed}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
@pytest.fixture
|
| 81 |
def mock_chromadb():
|
| 82 |
"""Mock ChromaDB client"""
|
| 83 |
+
with patch("chromadb.PersistentClient") as mock_client:
|
| 84 |
mock_collection = Mock()
|
| 85 |
mock_client.return_value.get_collection.return_value = mock_collection
|
| 86 |
yield mock_client
|
|
|
|
| 89 |
@pytest.fixture
|
| 90 |
def mock_sqlite():
|
| 91 |
"""Mock SQLite connections for memory"""
|
| 92 |
+
with patch("sqlite3.connect") as mock_connect:
|
| 93 |
mock_conn = Mock()
|
| 94 |
mock_connect.return_value = mock_conn
|
| 95 |
yield mock_conn
|
|
|
|
| 99 |
def sample_documents():
|
| 100 |
"""Provide sample documents for testing"""
|
| 101 |
doc1 = Mock()
|
| 102 |
+
doc1.page_content = (
|
| 103 |
+
"Question: How do I create an account?\nAnswer: Visit our website."
|
| 104 |
+
)
|
| 105 |
doc1.metadata = {
|
| 106 |
+
"id": "KB001",
|
| 107 |
+
"question": "How do I create an account?",
|
| 108 |
+
"content": "Visit our website.",
|
| 109 |
+
"section": "Account Management",
|
| 110 |
}
|
| 111 |
+
|
| 112 |
doc2 = Mock()
|
| 113 |
doc2.page_content = "Question: What are the fees?\nAnswer: 1% per transaction."
|
| 114 |
doc2.metadata = {
|
| 115 |
+
"id": "KB002",
|
| 116 |
+
"question": "What are the fees?",
|
| 117 |
+
"content": "1% per transaction.",
|
| 118 |
+
"section": "Fees",
|
| 119 |
}
|
| 120 |
+
|
| 121 |
return [doc1, doc2]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@pytest.fixture
|
| 125 |
+
def mock_genai_client():
|
| 126 |
+
"""Mock Google Generative AI client with new SDK structure"""
|
| 127 |
+
with patch("src.config.genai_client") as mock_client:
|
| 128 |
+
# Mock generate_content for LLM
|
| 129 |
+
mock_generate_response = Mock()
|
| 130 |
+
mock_generate_response.text = "Test response from LLM"
|
| 131 |
+
mock_client.models.generate_content.return_value = mock_generate_response
|
| 132 |
+
|
| 133 |
+
# Mock embed_content for embeddings
|
| 134 |
+
mock_embedding = Mock()
|
| 135 |
+
mock_embedding.values = [0.1, 0.2, 0.3]
|
| 136 |
+
mock_embed_response = Mock()
|
| 137 |
+
mock_embed_response.embeddings = [mock_embedding]
|
| 138 |
+
mock_client.models.embed_content.return_value = mock_embed_response
|
| 139 |
+
|
| 140 |
+
yield mock_client
|
tests/test_app.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for app module
|
| 3 |
+
Tests main orchestration logic
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import unittest
|
| 7 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 8 |
+
|
| 9 |
+
from app import get_context_and_answer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestApp(unittest.TestCase):
|
| 13 |
+
"""Test cases for app module"""
|
| 14 |
+
|
| 15 |
+
def setUp(self):
|
| 16 |
+
"""Set up test fixtures"""
|
| 17 |
+
self.message = "How do I create an account?"
|
| 18 |
+
self.history = [["Previous question", "Previous answer"]]
|
| 19 |
+
self.session_id = "test-session-123"
|
| 20 |
+
self.mock_intent_classifier = Mock()
|
| 21 |
+
self.mock_retriever = Mock()
|
| 22 |
+
|
| 23 |
+
@patch("app.log_timing_data")
|
| 24 |
+
@patch("app.log_response")
|
| 25 |
+
@patch("app.update_memory")
|
| 26 |
+
@patch("app.retrieve_memory")
|
| 27 |
+
@patch("app.create_session_config")
|
| 28 |
+
def test_get_context_and_answer_simple_intent(
|
| 29 |
+
self,
|
| 30 |
+
mock_session_config,
|
| 31 |
+
mock_retrieve_memory,
|
| 32 |
+
mock_update_memory,
|
| 33 |
+
mock_log_response,
|
| 34 |
+
mock_log_timing,
|
| 35 |
+
):
|
| 36 |
+
"""Test get_context_and_answer with simple intent (greeting)"""
|
| 37 |
+
# Setup mocks
|
| 38 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 39 |
+
mock_retrieve_memory.return_value = []
|
| 40 |
+
self.mock_intent_classifier.classify_intent.return_value = (
|
| 41 |
+
"greeting",
|
| 42 |
+
"Hello! How can I help you?",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Call function
|
| 46 |
+
answer = get_context_and_answer(
|
| 47 |
+
"Hello",
|
| 48 |
+
self.history,
|
| 49 |
+
self.session_id,
|
| 50 |
+
self.mock_intent_classifier,
|
| 51 |
+
self.mock_retriever,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# Verify intent was classified
|
| 55 |
+
self.mock_intent_classifier.classify_intent.assert_called_once_with("Hello")
|
| 56 |
+
|
| 57 |
+
# Should not use retriever for simple intent
|
| 58 |
+
self.mock_retriever.invoke.assert_not_called()
|
| 59 |
+
|
| 60 |
+
# Verify response
|
| 61 |
+
self.assertEqual(answer, "Hello! How can I help you?")
|
| 62 |
+
|
| 63 |
+
# Verify memory was updated
|
| 64 |
+
mock_update_memory.assert_called_once()
|
| 65 |
+
|
| 66 |
+
# Verify logging
|
| 67 |
+
mock_log_response.assert_called_once()
|
| 68 |
+
mock_log_timing.assert_called_once()
|
| 69 |
+
|
| 70 |
+
@patch("app.generate_xeno_response")
|
| 71 |
+
@patch("app.process_context")
|
| 72 |
+
@patch("app.generate_embeddings")
|
| 73 |
+
@patch("app.log_timing_data")
|
| 74 |
+
@patch("app.log_response")
|
| 75 |
+
@patch("app.update_memory")
|
| 76 |
+
@patch("app.retrieve_memory")
|
| 77 |
+
@patch("app.create_session_config")
|
| 78 |
+
def test_get_context_and_answer_query_intent(
|
| 79 |
+
self,
|
| 80 |
+
mock_session_config,
|
| 81 |
+
mock_retrieve_memory,
|
| 82 |
+
mock_update_memory,
|
| 83 |
+
mock_log_response,
|
| 84 |
+
mock_log_timing,
|
| 85 |
+
mock_generate_embeddings,
|
| 86 |
+
mock_process_context,
|
| 87 |
+
mock_generate_response,
|
| 88 |
+
):
|
| 89 |
+
"""Test get_context_and_answer with query intent"""
|
| 90 |
+
# Setup mocks
|
| 91 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 92 |
+
mock_retrieve_memory.return_value = []
|
| 93 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 94 |
+
|
| 95 |
+
# Mock retriever
|
| 96 |
+
mock_doc = Mock()
|
| 97 |
+
mock_doc.page_content = "Test content"
|
| 98 |
+
mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
|
| 99 |
+
self.mock_retriever.invoke.return_value = [mock_doc]
|
| 100 |
+
|
| 101 |
+
# Mock embeddings
|
| 102 |
+
mock_generate_embeddings.return_value = (
|
| 103 |
+
[0.1, 0.2, 0.3], # query embedding
|
| 104 |
+
[[0.2, 0.3, 0.4]], # doc embeddings
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Mock context processing
|
| 108 |
+
mock_process_context.return_value = (
|
| 109 |
+
"Formatted context",
|
| 110 |
+
["KB001"],
|
| 111 |
+
[("Q", "A")],
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Mock LLM response
|
| 115 |
+
mock_generate_response.return_value = "Generated answer"
|
| 116 |
+
|
| 117 |
+
# Call function
|
| 118 |
+
answer = get_context_and_answer(
|
| 119 |
+
self.message,
|
| 120 |
+
self.history,
|
| 121 |
+
self.session_id,
|
| 122 |
+
self.mock_intent_classifier,
|
| 123 |
+
self.mock_retriever,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Verify RAG pipeline was executed
|
| 127 |
+
self.mock_retriever.invoke.assert_called_once_with(self.message)
|
| 128 |
+
mock_generate_embeddings.assert_called_once()
|
| 129 |
+
mock_process_context.assert_called_once()
|
| 130 |
+
mock_generate_response.assert_called_once()
|
| 131 |
+
|
| 132 |
+
# Verify response
|
| 133 |
+
self.assertEqual(answer, "Generated answer")
|
| 134 |
+
|
| 135 |
+
# Verify logging
|
| 136 |
+
mock_log_response.assert_called_once()
|
| 137 |
+
mock_log_timing.assert_called_once()
|
| 138 |
+
|
| 139 |
+
@patch("app.log_timing_data")
|
| 140 |
+
@patch("app.log_response")
|
| 141 |
+
@patch("app.update_memory")
|
| 142 |
+
@patch("app.retrieve_memory")
|
| 143 |
+
@patch("app.create_session_config")
|
| 144 |
+
def test_get_context_and_answer_short_message(
|
| 145 |
+
self,
|
| 146 |
+
mock_session_config,
|
| 147 |
+
mock_retrieve_memory,
|
| 148 |
+
mock_update_memory,
|
| 149 |
+
mock_log_response,
|
| 150 |
+
mock_log_timing,
|
| 151 |
+
):
|
| 152 |
+
"""Test get_context_and_answer with very short message"""
|
| 153 |
+
# Setup mocks
|
| 154 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 155 |
+
mock_retrieve_memory.return_value = []
|
| 156 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 157 |
+
|
| 158 |
+
# Call function with short message
|
| 159 |
+
answer = get_context_and_answer(
|
| 160 |
+
"Hi",
|
| 161 |
+
self.history,
|
| 162 |
+
self.session_id,
|
| 163 |
+
self.mock_intent_classifier,
|
| 164 |
+
self.mock_retriever,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Should return a request for more details
|
| 168 |
+
self.assertIn("more details", answer)
|
| 169 |
+
|
| 170 |
+
# Should not invoke retriever
|
| 171 |
+
self.mock_retriever.invoke.assert_not_called()
|
| 172 |
+
|
| 173 |
+
@patch("app.generate_embeddings")
|
| 174 |
+
@patch("app.log_timing_data")
|
| 175 |
+
@patch("app.log_response")
|
| 176 |
+
@patch("app.update_memory")
|
| 177 |
+
@patch("app.retrieve_memory")
|
| 178 |
+
@patch("app.create_session_config")
|
| 179 |
+
def test_get_context_and_answer_low_similarity(
|
| 180 |
+
self,
|
| 181 |
+
mock_session_config,
|
| 182 |
+
mock_retrieve_memory,
|
| 183 |
+
mock_update_memory,
|
| 184 |
+
mock_log_response,
|
| 185 |
+
mock_log_timing,
|
| 186 |
+
mock_generate_embeddings,
|
| 187 |
+
):
|
| 188 |
+
"""Test get_context_and_answer with low similarity score"""
|
| 189 |
+
# Setup mocks
|
| 190 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 191 |
+
mock_retrieve_memory.return_value = []
|
| 192 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 193 |
+
|
| 194 |
+
# Mock retriever
|
| 195 |
+
mock_doc = Mock()
|
| 196 |
+
mock_doc.page_content = "Test content"
|
| 197 |
+
self.mock_retriever.invoke.return_value = [mock_doc]
|
| 198 |
+
|
| 199 |
+
# Mock embeddings with low similarity
|
| 200 |
+
mock_generate_embeddings.return_value = (
|
| 201 |
+
[0.1, 0.2, 0.3],
|
| 202 |
+
[[1.0, 0.0, 0.0]], # Will result in low cosine score
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Call function
|
| 206 |
+
answer = get_context_and_answer(
|
| 207 |
+
"Some random question",
|
| 208 |
+
self.history,
|
| 209 |
+
self.session_id,
|
| 210 |
+
self.mock_intent_classifier,
|
| 211 |
+
self.mock_retriever,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Should return "couldn't find" message
|
| 215 |
+
self.assertIn("couldn't find", answer)
|
| 216 |
+
|
| 217 |
+
@patch("app.log_timing_data")
|
| 218 |
+
@patch("app.log_response")
|
| 219 |
+
@patch("app.update_memory")
|
| 220 |
+
@patch("app.retrieve_memory")
|
| 221 |
+
@patch("app.create_session_config")
|
| 222 |
+
def test_get_context_and_answer_rag_error(
|
| 223 |
+
self,
|
| 224 |
+
mock_session_config,
|
| 225 |
+
mock_retrieve_memory,
|
| 226 |
+
mock_update_memory,
|
| 227 |
+
mock_log_response,
|
| 228 |
+
mock_log_timing,
|
| 229 |
+
):
|
| 230 |
+
"""Test get_context_and_answer handles RAG errors gracefully"""
|
| 231 |
+
# Setup mocks
|
| 232 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 233 |
+
mock_retrieve_memory.return_value = []
|
| 234 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 235 |
+
|
| 236 |
+
# Mock retriever to raise exception
|
| 237 |
+
self.mock_retriever.invoke.side_effect = Exception("Database error")
|
| 238 |
+
|
| 239 |
+
# Call function
|
| 240 |
+
answer = get_context_and_answer(
|
| 241 |
+
self.message,
|
| 242 |
+
self.history,
|
| 243 |
+
self.session_id,
|
| 244 |
+
self.mock_intent_classifier,
|
| 245 |
+
self.mock_retriever,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Should return technical issue message
|
| 249 |
+
self.assertIn("technical issue", answer)
|
| 250 |
+
|
| 251 |
+
# Verify error was logged
|
| 252 |
+
mock_log_timing.assert_called_once()
|
| 253 |
+
call_kwargs = mock_log_timing.call_args[1]
|
| 254 |
+
self.assertIsNotNone(call_kwargs.get("error_step"))
|
| 255 |
+
|
| 256 |
+
@patch("app.log_timing_data")
|
| 257 |
+
@patch("app.update_memory")
|
| 258 |
+
@patch("app.retrieve_memory")
|
| 259 |
+
@patch("app.create_session_config")
|
| 260 |
+
def test_get_context_and_answer_main_error(
|
| 261 |
+
self,
|
| 262 |
+
mock_session_config,
|
| 263 |
+
mock_retrieve_memory,
|
| 264 |
+
mock_update_memory,
|
| 265 |
+
mock_log_timing,
|
| 266 |
+
):
|
| 267 |
+
"""Test get_context_and_answer handles main pipeline errors"""
|
| 268 |
+
# Setup mocks
|
| 269 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 270 |
+
mock_retrieve_memory.side_effect = Exception("Memory error")
|
| 271 |
+
|
| 272 |
+
# Call function
|
| 273 |
+
answer = get_context_and_answer(
|
| 274 |
+
self.message,
|
| 275 |
+
self.history,
|
| 276 |
+
self.session_id,
|
| 277 |
+
self.mock_intent_classifier,
|
| 278 |
+
self.mock_retriever,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Should return error message
|
| 282 |
+
self.assertIn("error", answer)
|
| 283 |
+
|
| 284 |
+
# Verify error was logged
|
| 285 |
+
mock_log_timing.assert_called_once()
|
| 286 |
+
|
| 287 |
+
@patch("app.generate_xeno_response")
|
| 288 |
+
@patch("app.process_context")
|
| 289 |
+
@patch("app.generate_embeddings")
|
| 290 |
+
@patch("app.log_timing_data")
|
| 291 |
+
@patch("app.log_response")
|
| 292 |
+
@patch("app.update_memory")
|
| 293 |
+
@patch("app.retrieve_memory")
|
| 294 |
+
@patch("app.create_session_config")
|
| 295 |
+
def test_get_context_and_answer_with_chat_history(
|
| 296 |
+
self,
|
| 297 |
+
mock_session_config,
|
| 298 |
+
mock_retrieve_memory,
|
| 299 |
+
mock_update_memory,
|
| 300 |
+
mock_log_response,
|
| 301 |
+
mock_log_timing,
|
| 302 |
+
mock_generate_embeddings,
|
| 303 |
+
mock_process_context,
|
| 304 |
+
mock_generate_response,
|
| 305 |
+
):
|
| 306 |
+
"""Test get_context_and_answer passes chat history to LLM"""
|
| 307 |
+
# Setup mocks
|
| 308 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 309 |
+
chat_history = [
|
| 310 |
+
{"role": "user", "content": "Previous question"},
|
| 311 |
+
{"role": "assistant", "content": "Previous answer"},
|
| 312 |
+
]
|
| 313 |
+
mock_retrieve_memory.return_value = chat_history
|
| 314 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 315 |
+
|
| 316 |
+
# Mock retriever
|
| 317 |
+
mock_doc = Mock()
|
| 318 |
+
mock_doc.page_content = "Test content"
|
| 319 |
+
mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
|
| 320 |
+
self.mock_retriever.invoke.return_value = [mock_doc]
|
| 321 |
+
|
| 322 |
+
# Mock embeddings
|
| 323 |
+
mock_generate_embeddings.return_value = ([0.1, 0.2], [[0.9, 0.1]])
|
| 324 |
+
|
| 325 |
+
# Mock context processing
|
| 326 |
+
mock_process_context.return_value = ("Context", ["KB001"], [("Q", "A")])
|
| 327 |
+
|
| 328 |
+
# Mock LLM response
|
| 329 |
+
mock_generate_response.return_value = "Answer with context"
|
| 330 |
+
|
| 331 |
+
# Call function
|
| 332 |
+
answer = get_context_and_answer(
|
| 333 |
+
self.message,
|
| 334 |
+
self.history,
|
| 335 |
+
self.session_id,
|
| 336 |
+
self.mock_intent_classifier,
|
| 337 |
+
self.mock_retriever,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Verify chat history was passed to LLM
|
| 341 |
+
mock_generate_response.assert_called_once()
|
| 342 |
+
call_args = mock_generate_response.call_args[0]
|
| 343 |
+
self.assertEqual(call_args[2], chat_history)
|
| 344 |
+
|
| 345 |
+
@patch("app.PipelineTimer")
|
| 346 |
+
@patch("app.generate_xeno_response")
|
| 347 |
+
@patch("app.process_context")
|
| 348 |
+
@patch("app.generate_embeddings")
|
| 349 |
+
@patch("app.log_timing_data")
|
| 350 |
+
@patch("app.log_response")
|
| 351 |
+
@patch("app.update_memory")
|
| 352 |
+
@patch("app.retrieve_memory")
|
| 353 |
+
@patch("app.create_session_config")
|
| 354 |
+
def test_get_context_and_answer_timing(
|
| 355 |
+
self,
|
| 356 |
+
mock_session_config,
|
| 357 |
+
mock_retrieve_memory,
|
| 358 |
+
mock_update_memory,
|
| 359 |
+
mock_log_response,
|
| 360 |
+
mock_log_timing,
|
| 361 |
+
mock_generate_embeddings,
|
| 362 |
+
mock_process_context,
|
| 363 |
+
mock_generate_response,
|
| 364 |
+
mock_timer_class,
|
| 365 |
+
):
|
| 366 |
+
"""Test get_context_and_answer uses PipelineTimer correctly"""
|
| 367 |
+
# Setup mocks
|
| 368 |
+
mock_timer = Mock()
|
| 369 |
+
mock_timer.time_step = MagicMock()
|
| 370 |
+
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 371 |
+
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 372 |
+
mock_timer.get_timing_summary.return_value = {"total": 1.5}
|
| 373 |
+
mock_timer_class.return_value = mock_timer
|
| 374 |
+
|
| 375 |
+
mock_session_config.return_value = {"session_id": self.session_id}
|
| 376 |
+
mock_retrieve_memory.return_value = []
|
| 377 |
+
self.mock_intent_classifier.classify_intent.return_value = ("query", None)
|
| 378 |
+
|
| 379 |
+
# Mock retriever
|
| 380 |
+
mock_doc = Mock()
|
| 381 |
+
mock_doc.page_content = "Test"
|
| 382 |
+
mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
|
| 383 |
+
self.mock_retriever.invoke.return_value = [mock_doc]
|
| 384 |
+
|
| 385 |
+
# Mock embeddings
|
| 386 |
+
mock_generate_embeddings.return_value = ([0.1], [[0.9]])
|
| 387 |
+
mock_process_context.return_value = ("Context", ["KB001"], [("Q", "A")])
|
| 388 |
+
mock_generate_response.return_value = "Answer"
|
| 389 |
+
|
| 390 |
+
# Call function
|
| 391 |
+
get_context_and_answer(
|
| 392 |
+
self.message,
|
| 393 |
+
self.history,
|
| 394 |
+
self.session_id,
|
| 395 |
+
self.mock_intent_classifier,
|
| 396 |
+
self.mock_retriever,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
# Verify timer was used
|
| 400 |
+
mock_timer.reset.assert_called_once()
|
| 401 |
+
mock_timer.get_timing_summary.assert_called()
|
| 402 |
+
|
| 403 |
+
# Verify timing was logged
|
| 404 |
+
mock_log_timing.assert_called_once()
|
| 405 |
+
call_args = mock_log_timing.call_args[0]
|
| 406 |
+
# Second positional argument is session_id, third is timing_summary
|
| 407 |
+
self.assertIn("total", call_args[2])
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
if __name__ == "__main__":
|
| 411 |
+
unittest.main()
|
tests/test_intent_classifier.py
CHANGED
|
@@ -2,25 +2,27 @@
|
|
| 2 |
Unit tests for intent_classifier module
|
| 3 |
Tests the IntentClassifier class
|
| 4 |
"""
|
|
|
|
| 5 |
import unittest
|
| 6 |
from unittest.mock import Mock
|
|
|
|
| 7 |
from src.intent_classifier import IntentClassifier
|
| 8 |
|
| 9 |
|
| 10 |
class TestIntentClassifier(unittest.TestCase):
|
| 11 |
"""Test cases for IntentClassifier class"""
|
| 12 |
-
|
| 13 |
def setUp(self):
|
| 14 |
"""Set up test fixtures"""
|
| 15 |
self.classifier = IntentClassifier()
|
| 16 |
-
|
| 17 |
def test_initialization(self):
|
| 18 |
"""Test classifier initialization"""
|
| 19 |
self.assertIsNotNone(self.classifier.intent_patterns)
|
| 20 |
-
self.assertIn(
|
| 21 |
-
self.assertIn(
|
| 22 |
-
self.assertIn(
|
| 23 |
-
|
| 24 |
def test_classify_greeting(self):
|
| 25 |
"""Test classification of greeting messages"""
|
| 26 |
test_cases = [
|
|
@@ -29,15 +31,15 @@ class TestIntentClassifier(unittest.TestCase):
|
|
| 29 |
"Hey there",
|
| 30 |
"good morning",
|
| 31 |
"Good afternoon!",
|
| 32 |
-
"how are you"
|
| 33 |
]
|
| 34 |
-
|
| 35 |
for message in test_cases:
|
| 36 |
intent, response = self.classifier.classify_intent(message)
|
| 37 |
-
self.assertEqual(intent,
|
| 38 |
self.assertIsInstance(response, str)
|
| 39 |
self.assertGreater(len(response), 0)
|
| 40 |
-
|
| 41 |
def test_classify_thanks(self):
|
| 42 |
"""Test classification of thank you messages"""
|
| 43 |
test_cases = [
|
|
@@ -47,15 +49,15 @@ class TestIntentClassifier(unittest.TestCase):
|
|
| 47 |
"thx",
|
| 48 |
"I appreciate it",
|
| 49 |
"thanks a lot",
|
| 50 |
-
"thank you so much"
|
| 51 |
]
|
| 52 |
-
|
| 53 |
for message in test_cases:
|
| 54 |
intent, response = self.classifier.classify_intent(message)
|
| 55 |
-
self.assertEqual(intent,
|
| 56 |
self.assertIsInstance(response, str)
|
| 57 |
self.assertGreater(len(response), 0)
|
| 58 |
-
|
| 59 |
def test_classify_goodbye(self):
|
| 60 |
"""Test classification of goodbye messages"""
|
| 61 |
test_cases = [
|
|
@@ -65,78 +67,82 @@ class TestIntentClassifier(unittest.TestCase):
|
|
| 65 |
"farewell",
|
| 66 |
"take care",
|
| 67 |
"have a good day",
|
| 68 |
-
"talk to you later"
|
| 69 |
]
|
| 70 |
-
|
| 71 |
for message in test_cases:
|
| 72 |
intent, response = self.classifier.classify_intent(message)
|
| 73 |
-
self.assertEqual(intent,
|
| 74 |
self.assertIsInstance(response, str)
|
| 75 |
self.assertGreater(len(response), 0)
|
| 76 |
-
|
| 77 |
def test_classify_query(self):
|
| 78 |
"""Test classification of query messages"""
|
| 79 |
test_cases = [
|
| 80 |
"How do I open an account?",
|
| 81 |
"What are the transaction fees?",
|
| 82 |
"Can you help me with my balance?",
|
| 83 |
-
"Tell me about XENO services"
|
| 84 |
]
|
| 85 |
-
|
| 86 |
for message in test_cases:
|
| 87 |
intent, response = self.classifier.classify_intent(message)
|
| 88 |
-
self.assertEqual(intent,
|
| 89 |
-
self.assertEqual(response,
|
| 90 |
-
|
| 91 |
def test_case_insensitivity(self):
|
| 92 |
"""Test that classification is case insensitive"""
|
| 93 |
messages = [
|
| 94 |
-
("HI",
|
| 95 |
-
("THANK YOU",
|
| 96 |
-
("BYE",
|
| 97 |
-
("Hi There",
|
| 98 |
]
|
| 99 |
-
|
| 100 |
for message, expected_intent in messages:
|
| 101 |
intent, _ = self.classifier.classify_intent(message)
|
| 102 |
self.assertEqual(intent, expected_intent)
|
| 103 |
-
|
| 104 |
def test_with_timer(self):
|
| 105 |
"""Test classification with timer object"""
|
| 106 |
mock_timer = Mock()
|
| 107 |
mock_timer.time_step = Mock()
|
| 108 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 109 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 110 |
-
|
| 111 |
intent, response = self.classifier.classify_intent("hello", timer=mock_timer)
|
| 112 |
-
|
| 113 |
-
self.assertEqual(intent,
|
| 114 |
mock_timer.time_step.assert_called_once_with("intent_classification")
|
| 115 |
-
|
| 116 |
def test_is_simple_intent(self):
|
| 117 |
"""Test is_simple_intent method"""
|
| 118 |
-
self.assertTrue(self.classifier.is_simple_intent(
|
| 119 |
-
self.assertTrue(self.classifier.is_simple_intent(
|
| 120 |
-
self.assertFalse(self.classifier.is_simple_intent(
|
| 121 |
-
self.assertFalse(self.classifier.is_simple_intent(
|
| 122 |
-
|
| 123 |
def test_add_intent(self):
|
| 124 |
"""Test adding a new intent"""
|
| 125 |
-
patterns = [r
|
| 126 |
responses = ["This is a test response"]
|
| 127 |
-
|
| 128 |
-
self.classifier.add_intent(
|
| 129 |
-
|
| 130 |
# Verify intent was added
|
| 131 |
-
self.assertIn(
|
| 132 |
-
self.assertEqual(
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# Test classification with new intent
|
| 136 |
intent, response = self.classifier.classify_intent("testing")
|
| 137 |
-
self.assertEqual(intent,
|
| 138 |
self.assertEqual(response, "This is a test response")
|
| 139 |
-
|
| 140 |
def test_response_variety(self):
|
| 141 |
"""Test that responses vary (random selection)"""
|
| 142 |
# Multiple calls might return different responses
|
|
@@ -144,26 +150,26 @@ class TestIntentClassifier(unittest.TestCase):
|
|
| 144 |
for _ in range(20):
|
| 145 |
_, response = self.classifier.classify_intent("hello")
|
| 146 |
responses.add(response)
|
| 147 |
-
|
| 148 |
# Should have at least 1 response (could be more if random varies)
|
| 149 |
self.assertGreater(len(responses), 0)
|
| 150 |
-
|
| 151 |
def test_empty_message(self):
|
| 152 |
"""Test classification of empty or whitespace messages"""
|
| 153 |
test_cases = ["", " ", "\n", "\t"]
|
| 154 |
-
|
| 155 |
for message in test_cases:
|
| 156 |
intent, response = self.classifier.classify_intent(message)
|
| 157 |
-
self.assertEqual(intent,
|
| 158 |
-
self.assertEqual(response,
|
| 159 |
-
|
| 160 |
def test_mixed_intent_message(self):
|
| 161 |
"""Test messages that might match multiple patterns"""
|
| 162 |
# "hi thank you" should match greeting (first match wins)
|
| 163 |
intent, response = self.classifier.classify_intent("hi thank you")
|
| 164 |
# Should match the first pattern it encounters
|
| 165 |
-
self.assertIn(intent, [
|
| 166 |
|
| 167 |
|
| 168 |
-
if __name__ ==
|
| 169 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for intent_classifier module
|
| 3 |
Tests the IntentClassifier class
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import unittest
|
| 7 |
from unittest.mock import Mock
|
| 8 |
+
|
| 9 |
from src.intent_classifier import IntentClassifier
|
| 10 |
|
| 11 |
|
| 12 |
class TestIntentClassifier(unittest.TestCase):
|
| 13 |
"""Test cases for IntentClassifier class"""
|
| 14 |
+
|
| 15 |
def setUp(self):
|
| 16 |
"""Set up test fixtures"""
|
| 17 |
self.classifier = IntentClassifier()
|
| 18 |
+
|
| 19 |
def test_initialization(self):
|
| 20 |
"""Test classifier initialization"""
|
| 21 |
self.assertIsNotNone(self.classifier.intent_patterns)
|
| 22 |
+
self.assertIn("greeting", self.classifier.intent_patterns)
|
| 23 |
+
self.assertIn("thanks", self.classifier.intent_patterns)
|
| 24 |
+
self.assertIn("goodbye", self.classifier.intent_patterns)
|
| 25 |
+
|
| 26 |
def test_classify_greeting(self):
|
| 27 |
"""Test classification of greeting messages"""
|
| 28 |
test_cases = [
|
|
|
|
| 31 |
"Hey there",
|
| 32 |
"good morning",
|
| 33 |
"Good afternoon!",
|
| 34 |
+
"how are you",
|
| 35 |
]
|
| 36 |
+
|
| 37 |
for message in test_cases:
|
| 38 |
intent, response = self.classifier.classify_intent(message)
|
| 39 |
+
self.assertEqual(intent, "greeting", f"Failed for message: {message}")
|
| 40 |
self.assertIsInstance(response, str)
|
| 41 |
self.assertGreater(len(response), 0)
|
| 42 |
+
|
| 43 |
def test_classify_thanks(self):
|
| 44 |
"""Test classification of thank you messages"""
|
| 45 |
test_cases = [
|
|
|
|
| 49 |
"thx",
|
| 50 |
"I appreciate it",
|
| 51 |
"thanks a lot",
|
| 52 |
+
"thank you so much",
|
| 53 |
]
|
| 54 |
+
|
| 55 |
for message in test_cases:
|
| 56 |
intent, response = self.classifier.classify_intent(message)
|
| 57 |
+
self.assertEqual(intent, "thanks", f"Failed for message: {message}")
|
| 58 |
self.assertIsInstance(response, str)
|
| 59 |
self.assertGreater(len(response), 0)
|
| 60 |
+
|
| 61 |
def test_classify_goodbye(self):
|
| 62 |
"""Test classification of goodbye messages"""
|
| 63 |
test_cases = [
|
|
|
|
| 67 |
"farewell",
|
| 68 |
"take care",
|
| 69 |
"have a good day",
|
| 70 |
+
"talk to you later",
|
| 71 |
]
|
| 72 |
+
|
| 73 |
for message in test_cases:
|
| 74 |
intent, response = self.classifier.classify_intent(message)
|
| 75 |
+
self.assertEqual(intent, "goodbye", f"Failed for message: {message}")
|
| 76 |
self.assertIsInstance(response, str)
|
| 77 |
self.assertGreater(len(response), 0)
|
| 78 |
+
|
| 79 |
def test_classify_query(self):
|
| 80 |
"""Test classification of query messages"""
|
| 81 |
test_cases = [
|
| 82 |
"How do I open an account?",
|
| 83 |
"What are the transaction fees?",
|
| 84 |
"Can you help me with my balance?",
|
| 85 |
+
"Tell me about XENO services",
|
| 86 |
]
|
| 87 |
+
|
| 88 |
for message in test_cases:
|
| 89 |
intent, response = self.classifier.classify_intent(message)
|
| 90 |
+
self.assertEqual(intent, "query", f"Failed for message: {message}")
|
| 91 |
+
self.assertEqual(response, "")
|
| 92 |
+
|
| 93 |
def test_case_insensitivity(self):
|
| 94 |
"""Test that classification is case insensitive"""
|
| 95 |
messages = [
|
| 96 |
+
("HI", "greeting"),
|
| 97 |
+
("THANK YOU", "thanks"),
|
| 98 |
+
("BYE", "goodbye"),
|
| 99 |
+
("Hi There", "greeting"),
|
| 100 |
]
|
| 101 |
+
|
| 102 |
for message, expected_intent in messages:
|
| 103 |
intent, _ = self.classifier.classify_intent(message)
|
| 104 |
self.assertEqual(intent, expected_intent)
|
| 105 |
+
|
| 106 |
def test_with_timer(self):
|
| 107 |
"""Test classification with timer object"""
|
| 108 |
mock_timer = Mock()
|
| 109 |
mock_timer.time_step = Mock()
|
| 110 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 111 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 112 |
+
|
| 113 |
intent, response = self.classifier.classify_intent("hello", timer=mock_timer)
|
| 114 |
+
|
| 115 |
+
self.assertEqual(intent, "greeting")
|
| 116 |
mock_timer.time_step.assert_called_once_with("intent_classification")
|
| 117 |
+
|
| 118 |
def test_is_simple_intent(self):
|
| 119 |
"""Test is_simple_intent method"""
|
| 120 |
+
self.assertTrue(self.classifier.is_simple_intent("greeting"))
|
| 121 |
+
self.assertTrue(self.classifier.is_simple_intent("thanks"))
|
| 122 |
+
self.assertFalse(self.classifier.is_simple_intent("goodbye"))
|
| 123 |
+
self.assertFalse(self.classifier.is_simple_intent("query"))
|
| 124 |
+
|
| 125 |
def test_add_intent(self):
|
| 126 |
"""Test adding a new intent"""
|
| 127 |
+
patterns = [r"\b(test|testing)\b"]
|
| 128 |
responses = ["This is a test response"]
|
| 129 |
+
|
| 130 |
+
self.classifier.add_intent("test_intent", patterns, responses)
|
| 131 |
+
|
| 132 |
# Verify intent was added
|
| 133 |
+
self.assertIn("test_intent", self.classifier.intent_patterns)
|
| 134 |
+
self.assertEqual(
|
| 135 |
+
self.classifier.intent_patterns["test_intent"]["patterns"], patterns
|
| 136 |
+
)
|
| 137 |
+
self.assertEqual(
|
| 138 |
+
self.classifier.intent_patterns["test_intent"]["responses"], responses
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
# Test classification with new intent
|
| 142 |
intent, response = self.classifier.classify_intent("testing")
|
| 143 |
+
self.assertEqual(intent, "test_intent")
|
| 144 |
self.assertEqual(response, "This is a test response")
|
| 145 |
+
|
| 146 |
def test_response_variety(self):
|
| 147 |
"""Test that responses vary (random selection)"""
|
| 148 |
# Multiple calls might return different responses
|
|
|
|
| 150 |
for _ in range(20):
|
| 151 |
_, response = self.classifier.classify_intent("hello")
|
| 152 |
responses.add(response)
|
| 153 |
+
|
| 154 |
# Should have at least 1 response (could be more if random varies)
|
| 155 |
self.assertGreater(len(responses), 0)
|
| 156 |
+
|
| 157 |
def test_empty_message(self):
|
| 158 |
"""Test classification of empty or whitespace messages"""
|
| 159 |
test_cases = ["", " ", "\n", "\t"]
|
| 160 |
+
|
| 161 |
for message in test_cases:
|
| 162 |
intent, response = self.classifier.classify_intent(message)
|
| 163 |
+
self.assertEqual(intent, "query")
|
| 164 |
+
self.assertEqual(response, "")
|
| 165 |
+
|
| 166 |
def test_mixed_intent_message(self):
|
| 167 |
"""Test messages that might match multiple patterns"""
|
| 168 |
# "hi thank you" should match greeting (first match wins)
|
| 169 |
intent, response = self.classifier.classify_intent("hi thank you")
|
| 170 |
# Should match the first pattern it encounters
|
| 171 |
+
self.assertIn(intent, ["greeting", "thanks"])
|
| 172 |
|
| 173 |
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
unittest.main()
|
tests/test_interface.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for interface module
|
| 3 |
+
Tests Gradio interface functionality
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import unittest
|
| 7 |
+
import uuid
|
| 8 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 9 |
+
|
| 10 |
+
from src.interface import create_interface, respond
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestInterface(unittest.TestCase):
|
| 14 |
+
"""Test cases for interface module"""
|
| 15 |
+
|
| 16 |
+
def setUp(self):
|
| 17 |
+
"""Set up test fixtures"""
|
| 18 |
+
self.message = "How do I create an account?"
|
| 19 |
+
self.history = [["Previous question", "Previous answer"]]
|
| 20 |
+
self.session_id = str(uuid.uuid4())
|
| 21 |
+
self.mock_intent_classifier = Mock()
|
| 22 |
+
self.mock_retriever = Mock()
|
| 23 |
+
|
| 24 |
+
@patch("app.get_context_and_answer")
|
| 25 |
+
def test_respond_with_session_id(self, mock_get_answer):
|
| 26 |
+
"""Test respond function with existing session ID"""
|
| 27 |
+
mock_get_answer.return_value = "You can create an account by visiting our website."
|
| 28 |
+
|
| 29 |
+
result_msg, result_history = respond(
|
| 30 |
+
self.message,
|
| 31 |
+
self.history.copy(),
|
| 32 |
+
self.session_id,
|
| 33 |
+
self.mock_intent_classifier,
|
| 34 |
+
self.mock_retriever,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Verify get_context_and_answer was called
|
| 38 |
+
mock_get_answer.assert_called_once()
|
| 39 |
+
call_args = mock_get_answer.call_args[0]
|
| 40 |
+
self.assertEqual(call_args[0], self.message)
|
| 41 |
+
self.assertEqual(call_args[2], self.session_id)
|
| 42 |
+
|
| 43 |
+
# Check return values
|
| 44 |
+
self.assertEqual(result_msg, "")
|
| 45 |
+
self.assertEqual(len(result_history), 2)
|
| 46 |
+
self.assertEqual(result_history[-1][0], self.message)
|
| 47 |
+
self.assertEqual(
|
| 48 |
+
result_history[-1][1],
|
| 49 |
+
"You can create an account by visiting our website.",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
@patch("app.get_context_and_answer")
|
| 53 |
+
def test_respond_without_session_id(self, mock_get_answer):
|
| 54 |
+
"""Test respond function generates session ID when none provided"""
|
| 55 |
+
mock_get_answer.return_value = "Response"
|
| 56 |
+
|
| 57 |
+
result_msg, result_history = respond(
|
| 58 |
+
self.message,
|
| 59 |
+
[],
|
| 60 |
+
None,
|
| 61 |
+
self.mock_intent_classifier,
|
| 62 |
+
self.mock_retriever,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Should have called with a generated session ID
|
| 66 |
+
self.assertEqual(mock_get_answer.call_count, 1)
|
| 67 |
+
call_args = mock_get_answer.call_args[0]
|
| 68 |
+
generated_session_id = call_args[2]
|
| 69 |
+
|
| 70 |
+
# Verify it's a valid UUID
|
| 71 |
+
try:
|
| 72 |
+
uuid.UUID(generated_session_id)
|
| 73 |
+
valid_uuid = True
|
| 74 |
+
except ValueError:
|
| 75 |
+
valid_uuid = False
|
| 76 |
+
|
| 77 |
+
self.assertTrue(valid_uuid)
|
| 78 |
+
|
| 79 |
+
# Check return values
|
| 80 |
+
self.assertEqual(result_msg, "")
|
| 81 |
+
self.assertEqual(len(result_history), 1)
|
| 82 |
+
|
| 83 |
+
@patch("app.get_context_and_answer")
|
| 84 |
+
def test_respond_with_empty_history(self, mock_get_answer):
|
| 85 |
+
"""Test respond function with empty history"""
|
| 86 |
+
mock_get_answer.return_value = "Test response"
|
| 87 |
+
|
| 88 |
+
result_msg, result_history = respond(
|
| 89 |
+
"Test question",
|
| 90 |
+
[],
|
| 91 |
+
self.session_id,
|
| 92 |
+
self.mock_intent_classifier,
|
| 93 |
+
self.mock_retriever,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# History should have one entry
|
| 97 |
+
self.assertEqual(len(result_history), 1)
|
| 98 |
+
self.assertEqual(result_history[0][0], "Test question")
|
| 99 |
+
self.assertEqual(result_history[0][1], "Test response")
|
| 100 |
+
|
| 101 |
+
@patch("app.get_context_and_answer")
|
| 102 |
+
def test_respond_preserves_existing_history(self, mock_get_answer):
|
| 103 |
+
"""Test respond function preserves existing chat history"""
|
| 104 |
+
mock_get_answer.return_value = "New response"
|
| 105 |
+
|
| 106 |
+
initial_history = [
|
| 107 |
+
["Question 1", "Answer 1"],
|
| 108 |
+
["Question 2", "Answer 2"],
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
result_msg, result_history = respond(
|
| 112 |
+
"Question 3",
|
| 113 |
+
initial_history.copy(),
|
| 114 |
+
self.session_id,
|
| 115 |
+
self.mock_intent_classifier,
|
| 116 |
+
self.mock_retriever,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Should have 3 entries now
|
| 120 |
+
self.assertEqual(len(result_history), 3)
|
| 121 |
+
self.assertEqual(result_history[0][0], "Question 1")
|
| 122 |
+
self.assertEqual(result_history[1][0], "Question 2")
|
| 123 |
+
self.assertEqual(result_history[2][0], "Question 3")
|
| 124 |
+
|
| 125 |
+
def test_create_interface_returns_blocks(self):
|
| 126 |
+
"""Test create_interface returns Gradio Blocks interface"""
|
| 127 |
+
result = create_interface(self.mock_intent_classifier, self.mock_retriever)
|
| 128 |
+
|
| 129 |
+
# Should return a Gradio Blocks object
|
| 130 |
+
import gradio as gr
|
| 131 |
+
self.assertIsInstance(result, gr.Blocks)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
unittest.main()
|
tests/test_knowledge_base.py
CHANGED
|
@@ -2,22 +2,22 @@
|
|
| 2 |
Unit tests for knowledge_base module
|
| 3 |
Tests knowledge base loading and preparation
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
-
import pandas as pd
|
| 7 |
import json
|
| 8 |
-
import tempfile
|
| 9 |
import os
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class TestKnowledgeBase(unittest.TestCase):
|
| 19 |
"""Test cases for knowledge_base module"""
|
| 20 |
-
|
| 21 |
def setUp(self):
|
| 22 |
"""Set up test fixtures"""
|
| 23 |
# Create sample knowledge base data
|
|
@@ -29,7 +29,7 @@ class TestKnowledgeBase(unittest.TestCase):
|
|
| 29 |
"Section": "Account Management",
|
| 30 |
"Source": "Website",
|
| 31 |
"Owner": "Support Team",
|
| 32 |
-
"Tag": "account"
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"ID": "KB002",
|
|
@@ -38,35 +38,33 @@ class TestKnowledgeBase(unittest.TestCase):
|
|
| 38 |
"Section": "Fees",
|
| 39 |
"Source": "Documentation",
|
| 40 |
"Owner": "Finance Team",
|
| 41 |
-
"Tag": "fees"
|
| 42 |
-
}
|
| 43 |
]
|
| 44 |
-
|
| 45 |
# Create temporary JSON file
|
| 46 |
self.temp_file = tempfile.NamedTemporaryFile(
|
| 47 |
-
mode=
|
| 48 |
-
delete=False,
|
| 49 |
-
suffix='.json'
|
| 50 |
)
|
| 51 |
json.dump(self.sample_data, self.temp_file)
|
| 52 |
self.temp_file.close()
|
| 53 |
-
|
| 54 |
def tearDown(self):
|
| 55 |
"""Clean up test fixtures"""
|
| 56 |
if os.path.exists(self.temp_file.name):
|
| 57 |
os.unlink(self.temp_file.name)
|
| 58 |
-
|
| 59 |
def test_load_knowledge_base(self):
|
| 60 |
"""Test loading knowledge base from JSON file"""
|
| 61 |
df = load_knowledge_base(self.temp_file.name)
|
| 62 |
-
|
| 63 |
# Check DataFrame structure
|
| 64 |
self.assertIsInstance(df, pd.DataFrame)
|
| 65 |
self.assertEqual(len(df), 2)
|
| 66 |
-
self.assertIn(
|
| 67 |
-
self.assertIn(
|
| 68 |
-
self.assertIn(
|
| 69 |
-
|
| 70 |
def test_load_knowledge_base_drops_null_content(self):
|
| 71 |
"""Test that rows with null Content are dropped"""
|
| 72 |
data_with_null = self.sample_data + [
|
|
@@ -74,110 +72,124 @@ class TestKnowledgeBase(unittest.TestCase):
|
|
| 74 |
"ID": "KB003",
|
| 75 |
"Question": "Test question?",
|
| 76 |
"Content": None,
|
| 77 |
-
"Section": "Test"
|
| 78 |
}
|
| 79 |
]
|
| 80 |
-
|
| 81 |
temp_file_null = tempfile.NamedTemporaryFile(
|
| 82 |
-
mode=
|
| 83 |
-
delete=False,
|
| 84 |
-
suffix='.json'
|
| 85 |
)
|
| 86 |
json.dump(data_with_null, temp_file_null)
|
| 87 |
temp_file_null.close()
|
| 88 |
-
|
| 89 |
try:
|
| 90 |
df = load_knowledge_base(temp_file_null.name)
|
| 91 |
# Should only have 2 rows (null Content row dropped)
|
| 92 |
self.assertEqual(len(df), 2)
|
| 93 |
finally:
|
| 94 |
os.unlink(temp_file_null.name)
|
| 95 |
-
|
| 96 |
def test_prepare_documents(self):
|
| 97 |
"""Test preparing documents for vector store"""
|
| 98 |
documents, metadatas, ids = prepare_documents(self.sample_data)
|
| 99 |
-
|
| 100 |
# Check lengths match
|
| 101 |
self.assertEqual(len(documents), 2)
|
| 102 |
self.assertEqual(len(metadatas), 2)
|
| 103 |
self.assertEqual(len(ids), 2)
|
| 104 |
-
|
| 105 |
# Check document format
|
| 106 |
self.assertIn("Question:", documents[0])
|
| 107 |
self.assertIn("Answer:", documents[0])
|
| 108 |
self.assertIn("How do I create an account?", documents[0])
|
| 109 |
-
|
| 110 |
# Check metadata structure
|
| 111 |
-
self.assertEqual(metadatas[0][
|
| 112 |
-
self.assertEqual(metadatas[0][
|
| 113 |
-
self.assertEqual(metadatas[0][
|
| 114 |
-
|
| 115 |
# Check IDs
|
| 116 |
-
self.assertEqual(ids[0],
|
| 117 |
-
self.assertEqual(ids[1],
|
| 118 |
-
|
| 119 |
def test_prepare_documents_with_missing_fields(self):
|
| 120 |
"""Test preparing documents with missing optional fields"""
|
| 121 |
data_minimal = [
|
| 122 |
-
{
|
| 123 |
-
"ID": "KB001",
|
| 124 |
-
"Question": "Test question?",
|
| 125 |
-
"Content": "Test answer."
|
| 126 |
-
}
|
| 127 |
]
|
| 128 |
-
|
| 129 |
documents, metadatas, ids = prepare_documents(data_minimal)
|
| 130 |
-
|
| 131 |
# Should still work with defaults
|
| 132 |
self.assertEqual(len(documents), 1)
|
| 133 |
-
self.assertEqual(metadatas[0][
|
| 134 |
-
self.assertEqual(metadatas[0][
|
| 135 |
-
self.assertEqual(metadatas[0][
|
| 136 |
-
self.assertEqual(metadatas[0][
|
| 137 |
-
|
| 138 |
-
@patch(
|
| 139 |
def test_get_knowledge_base_data(self, mock_load):
|
| 140 |
"""Test get_knowledge_base_data function"""
|
| 141 |
# Mock the load_knowledge_base function
|
| 142 |
mock_df = pd.DataFrame(self.sample_data)
|
| 143 |
mock_load.return_value = mock_df
|
| 144 |
-
|
| 145 |
documents, metadatas, ids = get_knowledge_base_data()
|
| 146 |
-
|
| 147 |
# Verify load was called
|
| 148 |
mock_load.assert_called_once()
|
| 149 |
-
|
| 150 |
# Verify output
|
| 151 |
self.assertEqual(len(documents), 2)
|
| 152 |
self.assertEqual(len(metadatas), 2)
|
| 153 |
self.assertEqual(len(ids), 2)
|
| 154 |
-
|
| 155 |
def test_document_text_format(self):
|
| 156 |
"""Test that document text is properly formatted"""
|
| 157 |
documents, _, _ = prepare_documents(self.sample_data)
|
| 158 |
-
|
| 159 |
# Check first document format
|
| 160 |
expected_format = "Question: How do I create an account?\nAnswer: You can create an account by visiting our website."
|
| 161 |
self.assertEqual(documents[0], expected_format)
|
| 162 |
-
|
| 163 |
def test_empty_knowledge_base(self):
|
| 164 |
"""Test handling of empty knowledge base"""
|
| 165 |
empty_data = []
|
| 166 |
documents, metadatas, ids = prepare_documents(empty_data)
|
| 167 |
-
|
| 168 |
self.assertEqual(len(documents), 0)
|
| 169 |
self.assertEqual(len(metadatas), 0)
|
| 170 |
self.assertEqual(len(ids), 0)
|
| 171 |
-
|
| 172 |
def test_metadata_completeness(self):
|
| 173 |
"""Test that all metadata fields are present"""
|
| 174 |
_, metadatas, _ = prepare_documents(self.sample_data)
|
| 175 |
-
|
| 176 |
-
required_fields = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
for metadata in metadatas:
|
| 178 |
for field in required_fields:
|
| 179 |
self.assertIn(field, metadata)
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
if __name__ ==
|
| 183 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for knowledge_base module
|
| 3 |
Tests knowledge base loading and preparation
|
| 4 |
"""
|
| 5 |
+
|
|
|
|
| 6 |
import json
|
|
|
|
| 7 |
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
import unittest
|
| 10 |
+
from unittest.mock import patch
|
| 11 |
+
|
| 12 |
+
import pandas as pd
|
| 13 |
+
|
| 14 |
+
from src.knowledge_base import (get_knowledge_base_data, load_knowledge_base,
|
| 15 |
+
prepare_documents)
|
| 16 |
|
| 17 |
|
| 18 |
class TestKnowledgeBase(unittest.TestCase):
|
| 19 |
"""Test cases for knowledge_base module"""
|
| 20 |
+
|
| 21 |
def setUp(self):
|
| 22 |
"""Set up test fixtures"""
|
| 23 |
# Create sample knowledge base data
|
|
|
|
| 29 |
"Section": "Account Management",
|
| 30 |
"Source": "Website",
|
| 31 |
"Owner": "Support Team",
|
| 32 |
+
"Tag": "account",
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"ID": "KB002",
|
|
|
|
| 38 |
"Section": "Fees",
|
| 39 |
"Source": "Documentation",
|
| 40 |
"Owner": "Finance Team",
|
| 41 |
+
"Tag": "fees",
|
| 42 |
+
},
|
| 43 |
]
|
| 44 |
+
|
| 45 |
# Create temporary JSON file
|
| 46 |
self.temp_file = tempfile.NamedTemporaryFile(
|
| 47 |
+
mode="w", delete=False, suffix=".json"
|
|
|
|
|
|
|
| 48 |
)
|
| 49 |
json.dump(self.sample_data, self.temp_file)
|
| 50 |
self.temp_file.close()
|
| 51 |
+
|
| 52 |
def tearDown(self):
|
| 53 |
"""Clean up test fixtures"""
|
| 54 |
if os.path.exists(self.temp_file.name):
|
| 55 |
os.unlink(self.temp_file.name)
|
| 56 |
+
|
| 57 |
def test_load_knowledge_base(self):
|
| 58 |
"""Test loading knowledge base from JSON file"""
|
| 59 |
df = load_knowledge_base(self.temp_file.name)
|
| 60 |
+
|
| 61 |
# Check DataFrame structure
|
| 62 |
self.assertIsInstance(df, pd.DataFrame)
|
| 63 |
self.assertEqual(len(df), 2)
|
| 64 |
+
self.assertIn("ID", df.columns)
|
| 65 |
+
self.assertIn("Question", df.columns)
|
| 66 |
+
self.assertIn("Content", df.columns)
|
| 67 |
+
|
| 68 |
def test_load_knowledge_base_drops_null_content(self):
|
| 69 |
"""Test that rows with null Content are dropped"""
|
| 70 |
data_with_null = self.sample_data + [
|
|
|
|
| 72 |
"ID": "KB003",
|
| 73 |
"Question": "Test question?",
|
| 74 |
"Content": None,
|
| 75 |
+
"Section": "Test",
|
| 76 |
}
|
| 77 |
]
|
| 78 |
+
|
| 79 |
temp_file_null = tempfile.NamedTemporaryFile(
|
| 80 |
+
mode="w", delete=False, suffix=".json"
|
|
|
|
|
|
|
| 81 |
)
|
| 82 |
json.dump(data_with_null, temp_file_null)
|
| 83 |
temp_file_null.close()
|
| 84 |
+
|
| 85 |
try:
|
| 86 |
df = load_knowledge_base(temp_file_null.name)
|
| 87 |
# Should only have 2 rows (null Content row dropped)
|
| 88 |
self.assertEqual(len(df), 2)
|
| 89 |
finally:
|
| 90 |
os.unlink(temp_file_null.name)
|
| 91 |
+
|
| 92 |
def test_prepare_documents(self):
|
| 93 |
"""Test preparing documents for vector store"""
|
| 94 |
documents, metadatas, ids = prepare_documents(self.sample_data)
|
| 95 |
+
|
| 96 |
# Check lengths match
|
| 97 |
self.assertEqual(len(documents), 2)
|
| 98 |
self.assertEqual(len(metadatas), 2)
|
| 99 |
self.assertEqual(len(ids), 2)
|
| 100 |
+
|
| 101 |
# Check document format
|
| 102 |
self.assertIn("Question:", documents[0])
|
| 103 |
self.assertIn("Answer:", documents[0])
|
| 104 |
self.assertIn("How do I create an account?", documents[0])
|
| 105 |
+
|
| 106 |
# Check metadata structure
|
| 107 |
+
self.assertEqual(metadatas[0]["id"], "KB001")
|
| 108 |
+
self.assertEqual(metadatas[0]["question"], "How do I create an account?")
|
| 109 |
+
self.assertEqual(metadatas[0]["section"], "Account Management")
|
| 110 |
+
|
| 111 |
# Check IDs
|
| 112 |
+
self.assertEqual(ids[0], "KB001")
|
| 113 |
+
self.assertEqual(ids[1], "KB002")
|
| 114 |
+
|
| 115 |
def test_prepare_documents_with_missing_fields(self):
|
| 116 |
"""Test preparing documents with missing optional fields"""
|
| 117 |
data_minimal = [
|
| 118 |
+
{"ID": "KB001", "Question": "Test question?", "Content": "Test answer."}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
]
|
| 120 |
+
|
| 121 |
documents, metadatas, ids = prepare_documents(data_minimal)
|
| 122 |
+
|
| 123 |
# Should still work with defaults
|
| 124 |
self.assertEqual(len(documents), 1)
|
| 125 |
+
self.assertEqual(metadatas[0]["section"], "")
|
| 126 |
+
self.assertEqual(metadatas[0]["source"], "")
|
| 127 |
+
self.assertEqual(metadatas[0]["owner"], "")
|
| 128 |
+
self.assertEqual(metadatas[0]["tag"], "")
|
| 129 |
+
|
| 130 |
+
@patch("src.knowledge_base.load_knowledge_base")
|
| 131 |
def test_get_knowledge_base_data(self, mock_load):
|
| 132 |
"""Test get_knowledge_base_data function"""
|
| 133 |
# Mock the load_knowledge_base function
|
| 134 |
mock_df = pd.DataFrame(self.sample_data)
|
| 135 |
mock_load.return_value = mock_df
|
| 136 |
+
|
| 137 |
documents, metadatas, ids = get_knowledge_base_data()
|
| 138 |
+
|
| 139 |
# Verify load was called
|
| 140 |
mock_load.assert_called_once()
|
| 141 |
+
|
| 142 |
# Verify output
|
| 143 |
self.assertEqual(len(documents), 2)
|
| 144 |
self.assertEqual(len(metadatas), 2)
|
| 145 |
self.assertEqual(len(ids), 2)
|
| 146 |
+
|
| 147 |
def test_document_text_format(self):
|
| 148 |
"""Test that document text is properly formatted"""
|
| 149 |
documents, _, _ = prepare_documents(self.sample_data)
|
| 150 |
+
|
| 151 |
# Check first document format
|
| 152 |
expected_format = "Question: How do I create an account?\nAnswer: You can create an account by visiting our website."
|
| 153 |
self.assertEqual(documents[0], expected_format)
|
| 154 |
+
|
| 155 |
def test_empty_knowledge_base(self):
|
| 156 |
"""Test handling of empty knowledge base"""
|
| 157 |
empty_data = []
|
| 158 |
documents, metadatas, ids = prepare_documents(empty_data)
|
| 159 |
+
|
| 160 |
self.assertEqual(len(documents), 0)
|
| 161 |
self.assertEqual(len(metadatas), 0)
|
| 162 |
self.assertEqual(len(ids), 0)
|
| 163 |
+
|
| 164 |
def test_metadata_completeness(self):
|
| 165 |
"""Test that all metadata fields are present"""
|
| 166 |
_, metadatas, _ = prepare_documents(self.sample_data)
|
| 167 |
+
|
| 168 |
+
required_fields = [
|
| 169 |
+
"question",
|
| 170 |
+
"content",
|
| 171 |
+
"section",
|
| 172 |
+
"source",
|
| 173 |
+
"owner",
|
| 174 |
+
"tag",
|
| 175 |
+
"id",
|
| 176 |
+
]
|
| 177 |
for metadata in metadatas:
|
| 178 |
for field in required_fields:
|
| 179 |
self.assertIn(field, metadata)
|
| 180 |
|
| 181 |
+
@patch("src.knowledge_base.load_knowledge_base")
|
| 182 |
+
def test_get_knowledge_base_data_with_exception(self, mock_load):
|
| 183 |
+
"""Test get_knowledge_base_data handles exceptions"""
|
| 184 |
+
# Make load_knowledge_base raise an exception
|
| 185 |
+
mock_load.side_effect = Exception("File not found")
|
| 186 |
+
|
| 187 |
+
# Should raise the exception
|
| 188 |
+
with self.assertRaises(Exception) as context:
|
| 189 |
+
get_knowledge_base_data()
|
| 190 |
+
|
| 191 |
+
self.assertIn("File not found", str(context.exception))
|
| 192 |
+
|
| 193 |
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
unittest.main()
|
tests/test_logger.py
CHANGED
|
@@ -2,19 +2,16 @@
|
|
| 2 |
Unit tests for logger module
|
| 3 |
Tests Google Sheets logging functionality
|
| 4 |
"""
|
|
|
|
| 5 |
import unittest
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
from src.logger import
|
| 9 |
-
log_response,
|
| 10 |
-
log_timing_data,
|
| 11 |
-
_log_response_impl
|
| 12 |
-
)
|
| 13 |
|
| 14 |
|
| 15 |
class TestLogger(unittest.TestCase):
|
| 16 |
"""Test cases for logger module"""
|
| 17 |
-
|
| 18 |
def setUp(self):
|
| 19 |
"""Set up test fixtures"""
|
| 20 |
self.question = "How do I create an account?"
|
|
@@ -22,11 +19,11 @@ class TestLogger(unittest.TestCase):
|
|
| 22 |
self.source_ids = "KB001, KB002"
|
| 23 |
self.knowledge_pairs = [
|
| 24 |
("Question 1?", "Answer 1."),
|
| 25 |
-
("Question 2?", "Answer 2.")
|
| 26 |
]
|
| 27 |
self.session_id = "test_session_123"
|
| 28 |
-
|
| 29 |
-
@patch(
|
| 30 |
def test_log_response_impl(self, mock_sheet):
|
| 31 |
"""Test internal response logging implementation"""
|
| 32 |
_log_response_impl(
|
|
@@ -34,18 +31,20 @@ class TestLogger(unittest.TestCase):
|
|
| 34 |
self.answer,
|
| 35 |
self.source_ids,
|
| 36 |
self.knowledge_pairs,
|
| 37 |
-
self.session_id
|
| 38 |
)
|
| 39 |
-
|
| 40 |
# Verify append_row was called
|
| 41 |
mock_sheet.append_row.assert_called_once()
|
| 42 |
-
|
| 43 |
# Check the row data
|
| 44 |
call_args = mock_sheet.append_row.call_args
|
| 45 |
row = call_args[0][0]
|
| 46 |
-
|
| 47 |
# Verify row structure
|
| 48 |
-
self.assertEqual(
|
|
|
|
|
|
|
| 49 |
self.assertEqual(row[1], self.session_id)
|
| 50 |
self.assertEqual(row[2], self.question)
|
| 51 |
self.assertEqual(row[3], self.answer)
|
|
@@ -54,219 +53,198 @@ class TestLogger(unittest.TestCase):
|
|
| 54 |
self.assertEqual(row[6], "Answer 1.")
|
| 55 |
self.assertEqual(row[7], "Question 2?")
|
| 56 |
self.assertEqual(row[8], "Answer 2.")
|
| 57 |
-
|
| 58 |
-
@patch(
|
| 59 |
def test_log_response_with_timer(self, mock_sheet):
|
| 60 |
"""Test log_response with timer"""
|
| 61 |
mock_timer = Mock()
|
| 62 |
mock_timer.time_step = MagicMock()
|
| 63 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 64 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 65 |
-
|
| 66 |
log_response(
|
| 67 |
self.question,
|
| 68 |
self.answer,
|
| 69 |
self.source_ids,
|
| 70 |
self.knowledge_pairs,
|
| 71 |
self.session_id,
|
| 72 |
-
timer=mock_timer
|
| 73 |
)
|
| 74 |
-
|
| 75 |
# Verify timer was used
|
| 76 |
mock_timer.time_step.assert_called_once_with("response_logging")
|
| 77 |
-
|
| 78 |
-
@patch(
|
| 79 |
def test_log_response_empty_knowledge_pairs(self, mock_sheet):
|
| 80 |
"""Test logging with empty knowledge pairs"""
|
| 81 |
_log_response_impl(
|
| 82 |
-
self.question,
|
| 83 |
-
self.answer,
|
| 84 |
-
self.source_ids,
|
| 85 |
-
[],
|
| 86 |
-
self.session_id
|
| 87 |
)
|
| 88 |
-
|
| 89 |
# Should still work
|
| 90 |
mock_sheet.append_row.assert_called_once()
|
| 91 |
-
|
| 92 |
# Check that N/A is used for missing pairs
|
| 93 |
row = mock_sheet.append_row.call_args[0][0]
|
| 94 |
self.assertEqual(row[5], "N/A")
|
| 95 |
self.assertEqual(row[6], "N/A")
|
| 96 |
-
|
| 97 |
-
@patch(
|
| 98 |
def test_log_response_single_knowledge_pair(self, mock_sheet):
|
| 99 |
"""Test logging with single knowledge pair"""
|
| 100 |
single_pair = [("Single question?", "Single answer.")]
|
| 101 |
-
|
| 102 |
_log_response_impl(
|
| 103 |
-
self.question,
|
| 104 |
-
self.answer,
|
| 105 |
-
self.source_ids,
|
| 106 |
-
single_pair,
|
| 107 |
-
self.session_id
|
| 108 |
)
|
| 109 |
-
|
| 110 |
row = mock_sheet.append_row.call_args[0][0]
|
| 111 |
-
|
| 112 |
# First pair should be present
|
| 113 |
self.assertEqual(row[5], "Single question?")
|
| 114 |
self.assertEqual(row[6], "Single answer.")
|
| 115 |
-
|
| 116 |
# Second pair should be N/A
|
| 117 |
self.assertEqual(row[7], "N/A")
|
| 118 |
self.assertEqual(row[8], "N/A")
|
| 119 |
-
|
| 120 |
-
@patch(
|
| 121 |
-
@patch(
|
| 122 |
def test_log_response_fallback_on_error(self, mock_open, mock_sheet):
|
| 123 |
"""Test fallback to file logging on error"""
|
| 124 |
# Make append_row raise an exception
|
| 125 |
mock_sheet.append_row.side_effect = Exception("Connection error")
|
| 126 |
-
|
| 127 |
# Mock file operations
|
| 128 |
mock_file = MagicMock()
|
| 129 |
mock_open.return_value.__enter__.return_value = mock_file
|
| 130 |
-
|
| 131 |
# Should not raise exception
|
| 132 |
_log_response_impl(
|
| 133 |
self.question,
|
| 134 |
self.answer,
|
| 135 |
self.source_ids,
|
| 136 |
self.knowledge_pairs,
|
| 137 |
-
self.session_id
|
| 138 |
)
|
| 139 |
-
|
| 140 |
# Verify fallback file was opened
|
| 141 |
mock_open.assert_called_once_with("/tmp/response_log.txt", "a")
|
| 142 |
mock_file.write.assert_called_once()
|
| 143 |
-
|
| 144 |
-
@patch(
|
| 145 |
def test_log_timing_data(self, mock_sheet):
|
| 146 |
"""Test timing data logging"""
|
| 147 |
timing_summary = {
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
}
|
| 160 |
}
|
| 161 |
-
|
| 162 |
log_timing_data(
|
| 163 |
self.question,
|
| 164 |
self.session_id,
|
| 165 |
timing_summary,
|
| 166 |
error_step=None,
|
| 167 |
-
notes="Test note"
|
| 168 |
)
|
| 169 |
-
|
| 170 |
# Verify append_row was called
|
| 171 |
mock_sheet.append_row.assert_called_once()
|
| 172 |
-
|
| 173 |
# Check row structure
|
| 174 |
row = mock_sheet.append_row.call_args[0][0]
|
| 175 |
-
|
| 176 |
# Should have 15 fields
|
| 177 |
self.assertEqual(len(row), 15)
|
| 178 |
self.assertEqual(row[1], self.session_id)
|
| 179 |
self.assertEqual(row[3], 1500) # total_time_ms
|
| 180 |
-
self.assertEqual(row[4], 50)
|
| 181 |
-
self.assertEqual(row[5], 100)
|
| 182 |
self.assertEqual(row[14], "Test note") # notes
|
| 183 |
-
|
| 184 |
-
@patch(
|
| 185 |
def test_log_timing_data_with_error(self, mock_sheet):
|
| 186 |
"""Test timing data logging with error"""
|
| 187 |
timing_summary = {
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
'intent_classification': 50
|
| 191 |
-
}
|
| 192 |
}
|
| 193 |
-
|
| 194 |
log_timing_data(
|
| 195 |
self.question,
|
| 196 |
self.session_id,
|
| 197 |
timing_summary,
|
| 198 |
error_step="rag_retrieval",
|
| 199 |
-
notes="Error occurred"
|
| 200 |
)
|
| 201 |
-
|
| 202 |
row = mock_sheet.append_row.call_args[0][0]
|
| 203 |
-
|
| 204 |
# Check error_step is logged
|
| 205 |
self.assertEqual(row[13], "rag_retrieval")
|
| 206 |
self.assertEqual(row[14], "Error occurred")
|
| 207 |
-
|
| 208 |
-
@patch(
|
| 209 |
def test_log_timing_data_missing_steps(self, mock_sheet):
|
| 210 |
"""Test timing data with missing step times"""
|
| 211 |
timing_summary = {
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
# Other steps missing
|
| 216 |
-
}
|
| 217 |
}
|
| 218 |
-
|
| 219 |
-
log_timing_data(
|
| 220 |
-
|
| 221 |
-
self.session_id,
|
| 222 |
-
timing_summary
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
row = mock_sheet.append_row.call_args[0][0]
|
| 226 |
-
|
| 227 |
# Missing steps should default to 0
|
| 228 |
self.assertEqual(row[5], 0) # memory_retrieval
|
| 229 |
self.assertEqual(row[6], 0) # rag_retrieval
|
| 230 |
-
|
| 231 |
-
@patch(
|
| 232 |
def test_log_timing_data_long_question(self, mock_sheet):
|
| 233 |
"""Test timing data logging with long question (truncation)"""
|
| 234 |
long_question = "A" * 150 # 150 characters
|
| 235 |
-
|
| 236 |
-
timing_summary = {
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
log_timing_data(
|
| 242 |
-
long_question,
|
| 243 |
-
self.session_id,
|
| 244 |
-
timing_summary
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
row = mock_sheet.append_row.call_args[0][0]
|
| 248 |
-
|
| 249 |
# Question should be truncated to 103 chars (100 + "...")
|
| 250 |
self.assertEqual(len(row[2]), 103)
|
| 251 |
self.assertTrue(row[2].endswith("..."))
|
| 252 |
-
|
| 253 |
-
@patch(
|
| 254 |
-
@patch(
|
| 255 |
def test_log_timing_data_fallback_on_error(self, mock_open, mock_sheet):
|
| 256 |
"""Test fallback to file logging for timing data on error"""
|
| 257 |
mock_sheet.append_row.side_effect = Exception("Connection error")
|
| 258 |
-
|
| 259 |
mock_file = MagicMock()
|
| 260 |
mock_open.return_value.__enter__.return_value = mock_file
|
| 261 |
-
|
| 262 |
-
timing_summary = {
|
| 263 |
-
|
| 264 |
log_timing_data(self.question, self.session_id, timing_summary)
|
| 265 |
-
|
| 266 |
# Verify fallback file was opened
|
| 267 |
mock_open.assert_called_once_with("/tmp/timing_log.txt", "a")
|
| 268 |
mock_file.write.assert_called_once()
|
| 269 |
|
| 270 |
|
| 271 |
-
if __name__ ==
|
| 272 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for logger module
|
| 3 |
Tests Google Sheets logging functionality
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import unittest
|
| 7 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 8 |
+
|
| 9 |
+
from src.logger import _log_response_impl, log_response, log_timing_data
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class TestLogger(unittest.TestCase):
|
| 13 |
"""Test cases for logger module"""
|
| 14 |
+
|
| 15 |
def setUp(self):
|
| 16 |
"""Set up test fixtures"""
|
| 17 |
self.question = "How do I create an account?"
|
|
|
|
| 19 |
self.source_ids = "KB001, KB002"
|
| 20 |
self.knowledge_pairs = [
|
| 21 |
("Question 1?", "Answer 1."),
|
| 22 |
+
("Question 2?", "Answer 2."),
|
| 23 |
]
|
| 24 |
self.session_id = "test_session_123"
|
| 25 |
+
|
| 26 |
+
@patch("src.logger.response_sheet")
|
| 27 |
def test_log_response_impl(self, mock_sheet):
|
| 28 |
"""Test internal response logging implementation"""
|
| 29 |
_log_response_impl(
|
|
|
|
| 31 |
self.answer,
|
| 32 |
self.source_ids,
|
| 33 |
self.knowledge_pairs,
|
| 34 |
+
self.session_id,
|
| 35 |
)
|
| 36 |
+
|
| 37 |
# Verify append_row was called
|
| 38 |
mock_sheet.append_row.assert_called_once()
|
| 39 |
+
|
| 40 |
# Check the row data
|
| 41 |
call_args = mock_sheet.append_row.call_args
|
| 42 |
row = call_args[0][0]
|
| 43 |
+
|
| 44 |
# Verify row structure
|
| 45 |
+
self.assertEqual(
|
| 46 |
+
len(row), 9
|
| 47 |
+
) # timestamp, session_id, question, answer, source_ids, 4 knowledge fields
|
| 48 |
self.assertEqual(row[1], self.session_id)
|
| 49 |
self.assertEqual(row[2], self.question)
|
| 50 |
self.assertEqual(row[3], self.answer)
|
|
|
|
| 53 |
self.assertEqual(row[6], "Answer 1.")
|
| 54 |
self.assertEqual(row[7], "Question 2?")
|
| 55 |
self.assertEqual(row[8], "Answer 2.")
|
| 56 |
+
|
| 57 |
+
@patch("src.logger.response_sheet")
|
| 58 |
def test_log_response_with_timer(self, mock_sheet):
|
| 59 |
"""Test log_response with timer"""
|
| 60 |
mock_timer = Mock()
|
| 61 |
mock_timer.time_step = MagicMock()
|
| 62 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 63 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 64 |
+
|
| 65 |
log_response(
|
| 66 |
self.question,
|
| 67 |
self.answer,
|
| 68 |
self.source_ids,
|
| 69 |
self.knowledge_pairs,
|
| 70 |
self.session_id,
|
| 71 |
+
timer=mock_timer,
|
| 72 |
)
|
| 73 |
+
|
| 74 |
# Verify timer was used
|
| 75 |
mock_timer.time_step.assert_called_once_with("response_logging")
|
| 76 |
+
|
| 77 |
+
@patch("src.logger.response_sheet")
|
| 78 |
def test_log_response_empty_knowledge_pairs(self, mock_sheet):
|
| 79 |
"""Test logging with empty knowledge pairs"""
|
| 80 |
_log_response_impl(
|
| 81 |
+
self.question, self.answer, self.source_ids, [], self.session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
)
|
| 83 |
+
|
| 84 |
# Should still work
|
| 85 |
mock_sheet.append_row.assert_called_once()
|
| 86 |
+
|
| 87 |
# Check that N/A is used for missing pairs
|
| 88 |
row = mock_sheet.append_row.call_args[0][0]
|
| 89 |
self.assertEqual(row[5], "N/A")
|
| 90 |
self.assertEqual(row[6], "N/A")
|
| 91 |
+
|
| 92 |
+
@patch("src.logger.response_sheet")
|
| 93 |
def test_log_response_single_knowledge_pair(self, mock_sheet):
|
| 94 |
"""Test logging with single knowledge pair"""
|
| 95 |
single_pair = [("Single question?", "Single answer.")]
|
| 96 |
+
|
| 97 |
_log_response_impl(
|
| 98 |
+
self.question, self.answer, self.source_ids, single_pair, self.session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
+
|
| 101 |
row = mock_sheet.append_row.call_args[0][0]
|
| 102 |
+
|
| 103 |
# First pair should be present
|
| 104 |
self.assertEqual(row[5], "Single question?")
|
| 105 |
self.assertEqual(row[6], "Single answer.")
|
| 106 |
+
|
| 107 |
# Second pair should be N/A
|
| 108 |
self.assertEqual(row[7], "N/A")
|
| 109 |
self.assertEqual(row[8], "N/A")
|
| 110 |
+
|
| 111 |
+
@patch("src.logger.response_sheet")
|
| 112 |
+
@patch("builtins.open", create=True)
|
| 113 |
def test_log_response_fallback_on_error(self, mock_open, mock_sheet):
|
| 114 |
"""Test fallback to file logging on error"""
|
| 115 |
# Make append_row raise an exception
|
| 116 |
mock_sheet.append_row.side_effect = Exception("Connection error")
|
| 117 |
+
|
| 118 |
# Mock file operations
|
| 119 |
mock_file = MagicMock()
|
| 120 |
mock_open.return_value.__enter__.return_value = mock_file
|
| 121 |
+
|
| 122 |
# Should not raise exception
|
| 123 |
_log_response_impl(
|
| 124 |
self.question,
|
| 125 |
self.answer,
|
| 126 |
self.source_ids,
|
| 127 |
self.knowledge_pairs,
|
| 128 |
+
self.session_id,
|
| 129 |
)
|
| 130 |
+
|
| 131 |
# Verify fallback file was opened
|
| 132 |
mock_open.assert_called_once_with("/tmp/response_log.txt", "a")
|
| 133 |
mock_file.write.assert_called_once()
|
| 134 |
+
|
| 135 |
+
@patch("src.logger.timing_sheet")
|
| 136 |
def test_log_timing_data(self, mock_sheet):
|
| 137 |
"""Test timing data logging"""
|
| 138 |
timing_summary = {
|
| 139 |
+
"total_time_ms": 1500,
|
| 140 |
+
"step_times": {
|
| 141 |
+
"intent_classification": 50,
|
| 142 |
+
"memory_retrieval": 100,
|
| 143 |
+
"rag_retrieval": 200,
|
| 144 |
+
"embedding_generation": 300,
|
| 145 |
+
"similarity_calculation": 150,
|
| 146 |
+
"context_processing": 100,
|
| 147 |
+
"llm_generation": 500,
|
| 148 |
+
"memory_update": 50,
|
| 149 |
+
"response_logging": 50,
|
| 150 |
+
},
|
| 151 |
}
|
| 152 |
+
|
| 153 |
log_timing_data(
|
| 154 |
self.question,
|
| 155 |
self.session_id,
|
| 156 |
timing_summary,
|
| 157 |
error_step=None,
|
| 158 |
+
notes="Test note",
|
| 159 |
)
|
| 160 |
+
|
| 161 |
# Verify append_row was called
|
| 162 |
mock_sheet.append_row.assert_called_once()
|
| 163 |
+
|
| 164 |
# Check row structure
|
| 165 |
row = mock_sheet.append_row.call_args[0][0]
|
| 166 |
+
|
| 167 |
# Should have 15 fields
|
| 168 |
self.assertEqual(len(row), 15)
|
| 169 |
self.assertEqual(row[1], self.session_id)
|
| 170 |
self.assertEqual(row[3], 1500) # total_time_ms
|
| 171 |
+
self.assertEqual(row[4], 50) # intent_classification
|
| 172 |
+
self.assertEqual(row[5], 100) # memory_retrieval
|
| 173 |
self.assertEqual(row[14], "Test note") # notes
|
| 174 |
+
|
| 175 |
+
@patch("src.logger.timing_sheet")
|
| 176 |
def test_log_timing_data_with_error(self, mock_sheet):
|
| 177 |
"""Test timing data logging with error"""
|
| 178 |
timing_summary = {
|
| 179 |
+
"total_time_ms": 500,
|
| 180 |
+
"step_times": {"intent_classification": 50},
|
|
|
|
|
|
|
| 181 |
}
|
| 182 |
+
|
| 183 |
log_timing_data(
|
| 184 |
self.question,
|
| 185 |
self.session_id,
|
| 186 |
timing_summary,
|
| 187 |
error_step="rag_retrieval",
|
| 188 |
+
notes="Error occurred",
|
| 189 |
)
|
| 190 |
+
|
| 191 |
row = mock_sheet.append_row.call_args[0][0]
|
| 192 |
+
|
| 193 |
# Check error_step is logged
|
| 194 |
self.assertEqual(row[13], "rag_retrieval")
|
| 195 |
self.assertEqual(row[14], "Error occurred")
|
| 196 |
+
|
| 197 |
+
@patch("src.logger.timing_sheet")
|
| 198 |
def test_log_timing_data_missing_steps(self, mock_sheet):
|
| 199 |
"""Test timing data with missing step times"""
|
| 200 |
timing_summary = {
|
| 201 |
+
"total_time_ms": 100,
|
| 202 |
+
"step_times": {
|
| 203 |
+
"intent_classification": 100
|
| 204 |
# Other steps missing
|
| 205 |
+
},
|
| 206 |
}
|
| 207 |
+
|
| 208 |
+
log_timing_data(self.question, self.session_id, timing_summary)
|
| 209 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
row = mock_sheet.append_row.call_args[0][0]
|
| 211 |
+
|
| 212 |
# Missing steps should default to 0
|
| 213 |
self.assertEqual(row[5], 0) # memory_retrieval
|
| 214 |
self.assertEqual(row[6], 0) # rag_retrieval
|
| 215 |
+
|
| 216 |
+
@patch("src.logger.timing_sheet")
|
| 217 |
def test_log_timing_data_long_question(self, mock_sheet):
|
| 218 |
"""Test timing data logging with long question (truncation)"""
|
| 219 |
long_question = "A" * 150 # 150 characters
|
| 220 |
+
|
| 221 |
+
timing_summary = {"total_time_ms": 100, "step_times": {}}
|
| 222 |
+
|
| 223 |
+
log_timing_data(long_question, self.session_id, timing_summary)
|
| 224 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
row = mock_sheet.append_row.call_args[0][0]
|
| 226 |
+
|
| 227 |
# Question should be truncated to 103 chars (100 + "...")
|
| 228 |
self.assertEqual(len(row[2]), 103)
|
| 229 |
self.assertTrue(row[2].endswith("..."))
|
| 230 |
+
|
| 231 |
+
@patch("src.logger.timing_sheet")
|
| 232 |
+
@patch("builtins.open", create=True)
|
| 233 |
def test_log_timing_data_fallback_on_error(self, mock_open, mock_sheet):
|
| 234 |
"""Test fallback to file logging for timing data on error"""
|
| 235 |
mock_sheet.append_row.side_effect = Exception("Connection error")
|
| 236 |
+
|
| 237 |
mock_file = MagicMock()
|
| 238 |
mock_open.return_value.__enter__.return_value = mock_file
|
| 239 |
+
|
| 240 |
+
timing_summary = {"total_time_ms": 100, "step_times": {}}
|
| 241 |
+
|
| 242 |
log_timing_data(self.question, self.session_id, timing_summary)
|
| 243 |
+
|
| 244 |
# Verify fallback file was opened
|
| 245 |
mock_open.assert_called_once_with("/tmp/timing_log.txt", "a")
|
| 246 |
mock_file.write.assert_called_once()
|
| 247 |
|
| 248 |
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
unittest.main()
|
tests/test_memory.py
CHANGED
|
@@ -2,51 +2,42 @@
|
|
| 2 |
Unit tests for memory module
|
| 3 |
Tests LangGraph memory operations
|
| 4 |
"""
|
|
|
|
| 5 |
import unittest
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
from src.memory import (
|
| 11 |
-
update_memory,
|
| 12 |
-
retrieve_memory,
|
| 13 |
-
create_session_config,
|
| 14 |
-
_update_memory_impl,
|
| 15 |
-
_retrieve_memory_impl
|
| 16 |
-
)
|
| 17 |
|
| 18 |
|
| 19 |
class TestMemory(unittest.TestCase):
|
| 20 |
"""Test cases for memory module"""
|
| 21 |
-
|
| 22 |
def setUp(self):
|
| 23 |
"""Set up test fixtures"""
|
| 24 |
self.test_config = {
|
| 25 |
-
"configurable": {
|
| 26 |
-
"thread_id": "test_session_123",
|
| 27 |
-
"checkpoint_ns": ""
|
| 28 |
-
}
|
| 29 |
}
|
| 30 |
-
|
| 31 |
def test_create_session_config(self):
|
| 32 |
"""Test creating session config"""
|
| 33 |
session_id = "test_session_456"
|
| 34 |
config = create_session_config(session_id)
|
| 35 |
-
|
| 36 |
# Check structure
|
| 37 |
self.assertIn("configurable", config)
|
| 38 |
self.assertEqual(config["configurable"]["thread_id"], session_id)
|
| 39 |
self.assertEqual(config["configurable"]["checkpoint_ns"], "")
|
| 40 |
-
|
| 41 |
def test_create_session_config_default(self):
|
| 42 |
"""Test creating session config with default ID"""
|
| 43 |
config = create_session_config()
|
| 44 |
-
|
| 45 |
# Check structure
|
| 46 |
self.assertIn("configurable", config)
|
| 47 |
self.assertEqual(config["configurable"]["thread_id"], "default")
|
| 48 |
-
|
| 49 |
-
@patch(
|
| 50 |
def test_update_memory_impl(self, mock_memory):
|
| 51 |
"""Test internal memory update implementation"""
|
| 52 |
# Mock memory.get to return existing checkpoint
|
|
@@ -54,27 +45,27 @@ class TestMemory(unittest.TestCase):
|
|
| 54 |
"channel_values": {
|
| 55 |
"messages": [
|
| 56 |
{"role": "user", "content": "Previous question"},
|
| 57 |
-
{"role": "assistant", "content": "Previous answer"}
|
| 58 |
]
|
| 59 |
}
|
| 60 |
}
|
| 61 |
mock_memory.get.return_value = mock_checkpoint
|
| 62 |
-
|
| 63 |
user_message = "New question"
|
| 64 |
assistant_message = "New answer"
|
| 65 |
-
|
| 66 |
_update_memory_impl(self.test_config, user_message, assistant_message)
|
| 67 |
-
|
| 68 |
# Verify memory.get was called
|
| 69 |
mock_memory.get.assert_called_once_with(self.test_config)
|
| 70 |
-
|
| 71 |
# Verify memory.put was called
|
| 72 |
mock_memory.put.assert_called_once()
|
| 73 |
-
|
| 74 |
# Check the checkpoint that was saved
|
| 75 |
call_args = mock_memory.put.call_args
|
| 76 |
saved_checkpoint = call_args[0][1]
|
| 77 |
-
|
| 78 |
# Verify messages were appended
|
| 79 |
messages = saved_checkpoint["channel_values"]["messages"]
|
| 80 |
self.assertEqual(len(messages), 4) # 2 existing + 2 new
|
|
@@ -82,32 +73,32 @@ class TestMemory(unittest.TestCase):
|
|
| 82 |
self.assertEqual(messages[-2]["content"], user_message)
|
| 83 |
self.assertEqual(messages[-1]["role"], "assistant")
|
| 84 |
self.assertEqual(messages[-1]["content"], assistant_message)
|
| 85 |
-
|
| 86 |
-
@patch(
|
| 87 |
def test_update_memory_empty_checkpoint(self, mock_memory):
|
| 88 |
"""Test updating memory with empty checkpoint"""
|
| 89 |
# Mock memory.get to return None
|
| 90 |
mock_memory.get.return_value = None
|
| 91 |
-
|
| 92 |
user_message = "First question"
|
| 93 |
assistant_message = "First answer"
|
| 94 |
-
|
| 95 |
_update_memory_impl(self.test_config, user_message, assistant_message)
|
| 96 |
-
|
| 97 |
# Verify memory.put was called
|
| 98 |
mock_memory.put.assert_called_once()
|
| 99 |
-
|
| 100 |
# Check the checkpoint
|
| 101 |
call_args = mock_memory.put.call_args
|
| 102 |
saved_checkpoint = call_args[0][1]
|
| 103 |
messages = saved_checkpoint["channel_values"]["messages"]
|
| 104 |
-
|
| 105 |
# Should have 2 messages
|
| 106 |
self.assertEqual(len(messages), 2)
|
| 107 |
self.assertEqual(messages[0]["role"], "user")
|
| 108 |
self.assertEqual(messages[1]["role"], "assistant")
|
| 109 |
-
|
| 110 |
-
@patch(
|
| 111 |
def test_update_memory_with_timer(self, mock_memory):
|
| 112 |
"""Test update_memory with timer"""
|
| 113 |
mock_memory.get.return_value = {}
|
|
@@ -115,13 +106,13 @@ class TestMemory(unittest.TestCase):
|
|
| 115 |
mock_timer.time_step = MagicMock()
|
| 116 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 117 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 118 |
-
|
| 119 |
update_memory(self.test_config, "Test", "Answer", timer=mock_timer)
|
| 120 |
-
|
| 121 |
# Verify timer was used
|
| 122 |
mock_timer.time_step.assert_called_once_with("memory_update")
|
| 123 |
-
|
| 124 |
-
@patch(
|
| 125 |
def test_retrieve_memory_impl(self, mock_memory):
|
| 126 |
"""Test internal memory retrieval implementation"""
|
| 127 |
# Mock memory.get to return checkpoint with messages
|
|
@@ -131,33 +122,33 @@ class TestMemory(unittest.TestCase):
|
|
| 131 |
{"role": "user", "content": "Question 1"},
|
| 132 |
{"role": "assistant", "content": "Answer 1"},
|
| 133 |
{"role": "user", "content": "Question 2"},
|
| 134 |
-
{"role": "assistant", "content": "Answer 2"}
|
| 135 |
]
|
| 136 |
}
|
| 137 |
}
|
| 138 |
mock_memory.get.return_value = mock_checkpoint
|
| 139 |
-
|
| 140 |
messages = _retrieve_memory_impl(self.test_config)
|
| 141 |
-
|
| 142 |
# Verify memory.get was called
|
| 143 |
mock_memory.get.assert_called_once_with(self.test_config)
|
| 144 |
-
|
| 145 |
# Verify messages were retrieved
|
| 146 |
self.assertEqual(len(messages), 4)
|
| 147 |
self.assertEqual(messages[0]["content"], "Question 1")
|
| 148 |
-
|
| 149 |
-
@patch(
|
| 150 |
def test_retrieve_memory_empty(self, mock_memory):
|
| 151 |
"""Test retrieving memory when empty"""
|
| 152 |
# Mock memory.get to return None
|
| 153 |
mock_memory.get.return_value = None
|
| 154 |
-
|
| 155 |
messages = _retrieve_memory_impl(self.test_config)
|
| 156 |
-
|
| 157 |
# Should return empty list
|
| 158 |
self.assertEqual(messages, [])
|
| 159 |
-
|
| 160 |
-
@patch(
|
| 161 |
def test_retrieve_memory_with_timer(self, mock_memory):
|
| 162 |
"""Test retrieve_memory with timer"""
|
| 163 |
mock_memory.get.return_value = {}
|
|
@@ -165,22 +156,22 @@ class TestMemory(unittest.TestCase):
|
|
| 165 |
mock_timer.time_step = MagicMock()
|
| 166 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 167 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 168 |
-
|
| 169 |
retrieve_memory(self.test_config, timer=mock_timer)
|
| 170 |
-
|
| 171 |
# Verify timer was used
|
| 172 |
mock_timer.time_step.assert_called_once_with("memory_retrieval")
|
| 173 |
-
|
| 174 |
-
@patch(
|
| 175 |
def test_checkpoint_structure(self, mock_memory):
|
| 176 |
"""Test that checkpoint has correct structure"""
|
| 177 |
mock_memory.get.return_value = None
|
| 178 |
-
|
| 179 |
_update_memory_impl(self.test_config, "Test", "Answer")
|
| 180 |
-
|
| 181 |
call_args = mock_memory.put.call_args
|
| 182 |
checkpoint = call_args[0][1]
|
| 183 |
-
|
| 184 |
# Verify checkpoint structure
|
| 185 |
self.assertIn("v", checkpoint)
|
| 186 |
self.assertIn("id", checkpoint)
|
|
@@ -191,5 +182,5 @@ class TestMemory(unittest.TestCase):
|
|
| 191 |
self.assertEqual(checkpoint["v"], 1)
|
| 192 |
|
| 193 |
|
| 194 |
-
if __name__ ==
|
| 195 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for memory module
|
| 3 |
Tests LangGraph memory operations
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import unittest
|
| 7 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 8 |
+
|
| 9 |
+
from src.memory import (_retrieve_memory_impl, _update_memory_impl,
|
| 10 |
+
create_session_config, retrieve_memory, update_memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class TestMemory(unittest.TestCase):
|
| 14 |
"""Test cases for memory module"""
|
| 15 |
+
|
| 16 |
def setUp(self):
|
| 17 |
"""Set up test fixtures"""
|
| 18 |
self.test_config = {
|
| 19 |
+
"configurable": {"thread_id": "test_session_123", "checkpoint_ns": ""}
|
|
|
|
|
|
|
|
|
|
| 20 |
}
|
| 21 |
+
|
| 22 |
def test_create_session_config(self):
|
| 23 |
"""Test creating session config"""
|
| 24 |
session_id = "test_session_456"
|
| 25 |
config = create_session_config(session_id)
|
| 26 |
+
|
| 27 |
# Check structure
|
| 28 |
self.assertIn("configurable", config)
|
| 29 |
self.assertEqual(config["configurable"]["thread_id"], session_id)
|
| 30 |
self.assertEqual(config["configurable"]["checkpoint_ns"], "")
|
| 31 |
+
|
| 32 |
def test_create_session_config_default(self):
|
| 33 |
"""Test creating session config with default ID"""
|
| 34 |
config = create_session_config()
|
| 35 |
+
|
| 36 |
# Check structure
|
| 37 |
self.assertIn("configurable", config)
|
| 38 |
self.assertEqual(config["configurable"]["thread_id"], "default")
|
| 39 |
+
|
| 40 |
+
@patch("src.memory.memory")
|
| 41 |
def test_update_memory_impl(self, mock_memory):
|
| 42 |
"""Test internal memory update implementation"""
|
| 43 |
# Mock memory.get to return existing checkpoint
|
|
|
|
| 45 |
"channel_values": {
|
| 46 |
"messages": [
|
| 47 |
{"role": "user", "content": "Previous question"},
|
| 48 |
+
{"role": "assistant", "content": "Previous answer"},
|
| 49 |
]
|
| 50 |
}
|
| 51 |
}
|
| 52 |
mock_memory.get.return_value = mock_checkpoint
|
| 53 |
+
|
| 54 |
user_message = "New question"
|
| 55 |
assistant_message = "New answer"
|
| 56 |
+
|
| 57 |
_update_memory_impl(self.test_config, user_message, assistant_message)
|
| 58 |
+
|
| 59 |
# Verify memory.get was called
|
| 60 |
mock_memory.get.assert_called_once_with(self.test_config)
|
| 61 |
+
|
| 62 |
# Verify memory.put was called
|
| 63 |
mock_memory.put.assert_called_once()
|
| 64 |
+
|
| 65 |
# Check the checkpoint that was saved
|
| 66 |
call_args = mock_memory.put.call_args
|
| 67 |
saved_checkpoint = call_args[0][1]
|
| 68 |
+
|
| 69 |
# Verify messages were appended
|
| 70 |
messages = saved_checkpoint["channel_values"]["messages"]
|
| 71 |
self.assertEqual(len(messages), 4) # 2 existing + 2 new
|
|
|
|
| 73 |
self.assertEqual(messages[-2]["content"], user_message)
|
| 74 |
self.assertEqual(messages[-1]["role"], "assistant")
|
| 75 |
self.assertEqual(messages[-1]["content"], assistant_message)
|
| 76 |
+
|
| 77 |
+
@patch("src.memory.memory")
|
| 78 |
def test_update_memory_empty_checkpoint(self, mock_memory):
|
| 79 |
"""Test updating memory with empty checkpoint"""
|
| 80 |
# Mock memory.get to return None
|
| 81 |
mock_memory.get.return_value = None
|
| 82 |
+
|
| 83 |
user_message = "First question"
|
| 84 |
assistant_message = "First answer"
|
| 85 |
+
|
| 86 |
_update_memory_impl(self.test_config, user_message, assistant_message)
|
| 87 |
+
|
| 88 |
# Verify memory.put was called
|
| 89 |
mock_memory.put.assert_called_once()
|
| 90 |
+
|
| 91 |
# Check the checkpoint
|
| 92 |
call_args = mock_memory.put.call_args
|
| 93 |
saved_checkpoint = call_args[0][1]
|
| 94 |
messages = saved_checkpoint["channel_values"]["messages"]
|
| 95 |
+
|
| 96 |
# Should have 2 messages
|
| 97 |
self.assertEqual(len(messages), 2)
|
| 98 |
self.assertEqual(messages[0]["role"], "user")
|
| 99 |
self.assertEqual(messages[1]["role"], "assistant")
|
| 100 |
+
|
| 101 |
+
@patch("src.memory.memory")
|
| 102 |
def test_update_memory_with_timer(self, mock_memory):
|
| 103 |
"""Test update_memory with timer"""
|
| 104 |
mock_memory.get.return_value = {}
|
|
|
|
| 106 |
mock_timer.time_step = MagicMock()
|
| 107 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 108 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 109 |
+
|
| 110 |
update_memory(self.test_config, "Test", "Answer", timer=mock_timer)
|
| 111 |
+
|
| 112 |
# Verify timer was used
|
| 113 |
mock_timer.time_step.assert_called_once_with("memory_update")
|
| 114 |
+
|
| 115 |
+
@patch("src.memory.memory")
|
| 116 |
def test_retrieve_memory_impl(self, mock_memory):
|
| 117 |
"""Test internal memory retrieval implementation"""
|
| 118 |
# Mock memory.get to return checkpoint with messages
|
|
|
|
| 122 |
{"role": "user", "content": "Question 1"},
|
| 123 |
{"role": "assistant", "content": "Answer 1"},
|
| 124 |
{"role": "user", "content": "Question 2"},
|
| 125 |
+
{"role": "assistant", "content": "Answer 2"},
|
| 126 |
]
|
| 127 |
}
|
| 128 |
}
|
| 129 |
mock_memory.get.return_value = mock_checkpoint
|
| 130 |
+
|
| 131 |
messages = _retrieve_memory_impl(self.test_config)
|
| 132 |
+
|
| 133 |
# Verify memory.get was called
|
| 134 |
mock_memory.get.assert_called_once_with(self.test_config)
|
| 135 |
+
|
| 136 |
# Verify messages were retrieved
|
| 137 |
self.assertEqual(len(messages), 4)
|
| 138 |
self.assertEqual(messages[0]["content"], "Question 1")
|
| 139 |
+
|
| 140 |
+
@patch("src.memory.memory")
|
| 141 |
def test_retrieve_memory_empty(self, mock_memory):
|
| 142 |
"""Test retrieving memory when empty"""
|
| 143 |
# Mock memory.get to return None
|
| 144 |
mock_memory.get.return_value = None
|
| 145 |
+
|
| 146 |
messages = _retrieve_memory_impl(self.test_config)
|
| 147 |
+
|
| 148 |
# Should return empty list
|
| 149 |
self.assertEqual(messages, [])
|
| 150 |
+
|
| 151 |
+
@patch("src.memory.memory")
|
| 152 |
def test_retrieve_memory_with_timer(self, mock_memory):
|
| 153 |
"""Test retrieve_memory with timer"""
|
| 154 |
mock_memory.get.return_value = {}
|
|
|
|
| 156 |
mock_timer.time_step = MagicMock()
|
| 157 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 158 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 159 |
+
|
| 160 |
retrieve_memory(self.test_config, timer=mock_timer)
|
| 161 |
+
|
| 162 |
# Verify timer was used
|
| 163 |
mock_timer.time_step.assert_called_once_with("memory_retrieval")
|
| 164 |
+
|
| 165 |
+
@patch("src.memory.memory")
|
| 166 |
def test_checkpoint_structure(self, mock_memory):
|
| 167 |
"""Test that checkpoint has correct structure"""
|
| 168 |
mock_memory.get.return_value = None
|
| 169 |
+
|
| 170 |
_update_memory_impl(self.test_config, "Test", "Answer")
|
| 171 |
+
|
| 172 |
call_args = mock_memory.put.call_args
|
| 173 |
checkpoint = call_args[0][1]
|
| 174 |
+
|
| 175 |
# Verify checkpoint structure
|
| 176 |
self.assertIn("v", checkpoint)
|
| 177 |
self.assertIn("id", checkpoint)
|
|
|
|
| 182 |
self.assertEqual(checkpoint["v"], 1)
|
| 183 |
|
| 184 |
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
unittest.main()
|
tests/test_response_generator.py
CHANGED
|
@@ -2,199 +2,158 @@
|
|
| 2 |
Unit tests for response_generator module
|
| 3 |
Tests LLM response generation functionality
|
| 4 |
"""
|
|
|
|
| 5 |
import unittest
|
| 6 |
-
from unittest.mock import
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
)
|
| 12 |
|
| 13 |
|
| 14 |
class TestResponseGenerator(unittest.TestCase):
|
| 15 |
"""Test cases for response_generator module"""
|
| 16 |
-
|
| 17 |
def setUp(self):
|
| 18 |
"""Set up test fixtures"""
|
| 19 |
self.context = """Knowledge Entry 1:
|
| 20 |
Q: How do I create an account?
|
| 21 |
A: Visit our website and click Sign Up.
|
| 22 |
----------------------------------------"""
|
| 23 |
-
|
| 24 |
self.question = "How can I create an account?"
|
| 25 |
-
|
| 26 |
self.chat_history = [
|
| 27 |
{"role": "user", "content": "Hello"},
|
| 28 |
-
{"role": "assistant", "content": "Hi! How can I help you?"}
|
| 29 |
]
|
| 30 |
-
|
| 31 |
def test_format_chat_history(self):
|
| 32 |
"""Test formatting chat history"""
|
| 33 |
formatted = format_chat_history(self.chat_history)
|
| 34 |
-
|
| 35 |
# Check format
|
| 36 |
self.assertIn("User: Hello", formatted)
|
| 37 |
self.assertIn("Assistant: Hi! How can I help you?", formatted)
|
| 38 |
self.assertIn("\n", formatted)
|
| 39 |
-
|
| 40 |
def test_format_chat_history_empty(self):
|
| 41 |
"""Test formatting empty chat history"""
|
| 42 |
formatted = format_chat_history([])
|
| 43 |
self.assertEqual(formatted, "No previous conversation")
|
| 44 |
-
|
| 45 |
def test_format_chat_history_single_message(self):
|
| 46 |
"""Test formatting single message"""
|
| 47 |
history = [{"role": "user", "content": "Hello"}]
|
| 48 |
formatted = format_chat_history(history)
|
| 49 |
self.assertEqual(formatted, "User: Hello")
|
| 50 |
-
|
| 51 |
def test_format_chat_history_missing_fields(self):
|
| 52 |
"""Test formatting with missing fields"""
|
| 53 |
history = [
|
| 54 |
{"role": "user"}, # Missing content
|
| 55 |
-
{"content": "Test"} # Missing role
|
| 56 |
]
|
| 57 |
formatted = format_chat_history(history)
|
| 58 |
self.assertIn("User:", formatted)
|
| 59 |
self.assertIn("Unknown:", formatted)
|
| 60 |
-
|
| 61 |
-
@patch(
|
| 62 |
-
def test_generate_response_impl(self,
|
| 63 |
"""Test internal response generation implementation"""
|
| 64 |
-
#
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
mock_response.text = "You can create an account by visiting our website."
|
| 68 |
-
mock_model.generate_content.return_value = mock_response
|
| 69 |
-
mock_model_class.return_value = mock_model
|
| 70 |
-
|
| 71 |
response = _generate_response_impl(
|
| 72 |
-
self.context,
|
| 73 |
-
self.question,
|
| 74 |
-
self.chat_history
|
| 75 |
)
|
| 76 |
-
|
| 77 |
-
# Verify
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
# Check response
|
| 84 |
self.assertEqual(response, "You can create an account by visiting our website.")
|
| 85 |
-
|
| 86 |
-
@patch(
|
| 87 |
-
def test_generate_response_with_empty_history(self,
|
| 88 |
"""Test generating response with empty history"""
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
mock_model_class.return_value = mock_model
|
| 94 |
-
|
| 95 |
-
response = _generate_response_impl(
|
| 96 |
-
self.context,
|
| 97 |
-
self.question,
|
| 98 |
-
[]
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
# Verify it still works
|
| 102 |
self.assertEqual(response, "Test response")
|
| 103 |
-
|
| 104 |
# Check that "None" was used for history in prompt
|
| 105 |
-
|
| 106 |
-
prompt =
|
| 107 |
self.assertIn("None", prompt)
|
| 108 |
-
|
| 109 |
-
@patch(
|
| 110 |
-
def test_prompt_structure(self,
|
| 111 |
"""Test that prompt includes all necessary components"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
mock_model_class.return_value = mock_model
|
| 117 |
-
|
| 118 |
-
_generate_response_impl(
|
| 119 |
-
self.context,
|
| 120 |
-
self.question,
|
| 121 |
-
self.chat_history
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
# Get the prompt that was sent
|
| 125 |
-
|
| 126 |
-
prompt =
|
| 127 |
-
|
| 128 |
# Verify prompt structure
|
| 129 |
self.assertIn("HISTORY", prompt)
|
| 130 |
self.assertIn("CONTEXT", prompt)
|
| 131 |
self.assertIn("QUESTION", prompt)
|
| 132 |
self.assertIn(self.context, prompt)
|
| 133 |
self.assertIn(self.question, prompt)
|
| 134 |
-
|
| 135 |
-
@patch(
|
| 136 |
-
def test_generate_xeno_response_with_timer(self,
|
| 137 |
"""Test generate_xeno_response with timer"""
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
mock_response.text = "Test response"
|
| 141 |
-
mock_model.generate_content.return_value = mock_response
|
| 142 |
-
mock_model_class.return_value = mock_model
|
| 143 |
-
|
| 144 |
mock_timer = Mock()
|
| 145 |
mock_timer.time_step = MagicMock()
|
| 146 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 147 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 148 |
-
|
| 149 |
response = generate_xeno_response(
|
| 150 |
-
self.context,
|
| 151 |
-
self.question,
|
| 152 |
-
self.chat_history,
|
| 153 |
-
timer=mock_timer
|
| 154 |
)
|
| 155 |
-
|
| 156 |
# Verify timer was used
|
| 157 |
mock_timer.time_step.assert_called_once_with("llm_generation")
|
| 158 |
-
|
| 159 |
# Verify response
|
| 160 |
self.assertEqual(response, "Test response")
|
| 161 |
-
|
| 162 |
-
@patch(
|
| 163 |
-
def test_response_text_stripping(self,
|
| 164 |
"""Test that response text is stripped of whitespace"""
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
response = _generate_response_impl(
|
| 172 |
-
self.context,
|
| 173 |
-
self.question,
|
| 174 |
-
[]
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# Should be stripped
|
| 178 |
self.assertEqual(response, "Test response with spaces")
|
| 179 |
-
|
| 180 |
-
@patch(
|
| 181 |
-
def test_system_prompt_inclusion(self,
|
| 182 |
"""Test that system prompt is included in generated prompt"""
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
mock_response.text = "Test"
|
| 186 |
-
mock_model.generate_content.return_value = mock_response
|
| 187 |
-
mock_model_class.return_value = mock_model
|
| 188 |
-
|
| 189 |
_generate_response_impl(self.context, self.question, [])
|
| 190 |
-
|
| 191 |
# Get the prompt
|
| 192 |
-
|
| 193 |
-
prompt =
|
| 194 |
-
|
| 195 |
# Should contain system prompt text
|
| 196 |
self.assertIn("XENO Support Assistant", prompt)
|
| 197 |
|
| 198 |
|
| 199 |
-
if __name__ ==
|
| 200 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for response_generator module
|
| 3 |
Tests LLM response generation functionality
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import unittest
|
| 7 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 8 |
+
|
| 9 |
+
from src.response_generator import (_generate_response_impl,
|
| 10 |
+
format_chat_history,
|
| 11 |
+
generate_xeno_response)
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class TestResponseGenerator(unittest.TestCase):
|
| 15 |
"""Test cases for response_generator module"""
|
| 16 |
+
|
| 17 |
def setUp(self):
|
| 18 |
"""Set up test fixtures"""
|
| 19 |
self.context = """Knowledge Entry 1:
|
| 20 |
Q: How do I create an account?
|
| 21 |
A: Visit our website and click Sign Up.
|
| 22 |
----------------------------------------"""
|
| 23 |
+
|
| 24 |
self.question = "How can I create an account?"
|
| 25 |
+
|
| 26 |
self.chat_history = [
|
| 27 |
{"role": "user", "content": "Hello"},
|
| 28 |
+
{"role": "assistant", "content": "Hi! How can I help you?"},
|
| 29 |
]
|
| 30 |
+
|
| 31 |
def test_format_chat_history(self):
|
| 32 |
"""Test formatting chat history"""
|
| 33 |
formatted = format_chat_history(self.chat_history)
|
| 34 |
+
|
| 35 |
# Check format
|
| 36 |
self.assertIn("User: Hello", formatted)
|
| 37 |
self.assertIn("Assistant: Hi! How can I help you?", formatted)
|
| 38 |
self.assertIn("\n", formatted)
|
| 39 |
+
|
| 40 |
def test_format_chat_history_empty(self):
|
| 41 |
"""Test formatting empty chat history"""
|
| 42 |
formatted = format_chat_history([])
|
| 43 |
self.assertEqual(formatted, "No previous conversation")
|
| 44 |
+
|
| 45 |
def test_format_chat_history_single_message(self):
|
| 46 |
"""Test formatting single message"""
|
| 47 |
history = [{"role": "user", "content": "Hello"}]
|
| 48 |
formatted = format_chat_history(history)
|
| 49 |
self.assertEqual(formatted, "User: Hello")
|
| 50 |
+
|
| 51 |
def test_format_chat_history_missing_fields(self):
|
| 52 |
"""Test formatting with missing fields"""
|
| 53 |
history = [
|
| 54 |
{"role": "user"}, # Missing content
|
| 55 |
+
{"content": "Test"}, # Missing role
|
| 56 |
]
|
| 57 |
formatted = format_chat_history(history)
|
| 58 |
self.assertIn("User:", formatted)
|
| 59 |
self.assertIn("Unknown:", formatted)
|
| 60 |
+
|
| 61 |
+
@patch("src.response_generator.genai_client")
|
| 62 |
+
def test_generate_response_impl(self, mock_genai_client):
|
| 63 |
"""Test internal response generation implementation"""
|
| 64 |
+
# Configure mock response
|
| 65 |
+
mock_genai_client.models.generate_content.return_value.text = "You can create an account by visiting our website."
|
| 66 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
response = _generate_response_impl(
|
| 68 |
+
self.context, self.question, self.chat_history
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
+
|
| 71 |
+
# Verify generate_content was called with model and content
|
| 72 |
+
mock_genai_client.models.generate_content.assert_called_once()
|
| 73 |
+
call_kwargs = mock_genai_client.models.generate_content.call_args[1]
|
| 74 |
+
self.assertIn("model", call_kwargs)
|
| 75 |
+
self.assertIn("contents", call_kwargs)
|
| 76 |
+
|
| 77 |
# Check response
|
| 78 |
self.assertEqual(response, "You can create an account by visiting our website.")
|
| 79 |
+
|
| 80 |
+
@patch("src.response_generator.genai_client")
|
| 81 |
+
def test_generate_response_with_empty_history(self, mock_genai_client):
|
| 82 |
"""Test generating response with empty history"""
|
| 83 |
+
mock_genai_client.models.generate_content.return_value.text = "Test response"
|
| 84 |
+
|
| 85 |
+
response = _generate_response_impl(self.context, self.question, [])
|
| 86 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# Verify it still works
|
| 88 |
self.assertEqual(response, "Test response")
|
| 89 |
+
|
| 90 |
# Check that "None" was used for history in prompt
|
| 91 |
+
call_kwargs = mock_genai_client.models.generate_content.call_args[1]
|
| 92 |
+
prompt = call_kwargs["contents"]
|
| 93 |
self.assertIn("None", prompt)
|
| 94 |
+
|
| 95 |
+
@patch("src.response_generator.genai_client")
|
| 96 |
+
def test_prompt_structure(self, mock_genai_client):
|
| 97 |
"""Test that prompt includes all necessary components"""
|
| 98 |
+
mock_genai_client.models.generate_content.return_value.text = "Test response"
|
| 99 |
+
|
| 100 |
+
_generate_response_impl(self.context, self.question, self.chat_history)
|
| 101 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
# Get the prompt that was sent
|
| 103 |
+
call_kwargs = mock_genai_client.models.generate_content.call_args[1]
|
| 104 |
+
prompt = call_kwargs["contents"]
|
| 105 |
+
|
| 106 |
# Verify prompt structure
|
| 107 |
self.assertIn("HISTORY", prompt)
|
| 108 |
self.assertIn("CONTEXT", prompt)
|
| 109 |
self.assertIn("QUESTION", prompt)
|
| 110 |
self.assertIn(self.context, prompt)
|
| 111 |
self.assertIn(self.question, prompt)
|
| 112 |
+
|
| 113 |
+
@patch("src.response_generator.genai_client")
|
| 114 |
+
def test_generate_xeno_response_with_timer(self, mock_genai_client):
|
| 115 |
"""Test generate_xeno_response with timer"""
|
| 116 |
+
mock_genai_client.models.generate_content.return_value.text = "Test response"
|
| 117 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
mock_timer = Mock()
|
| 119 |
mock_timer.time_step = MagicMock()
|
| 120 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 121 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 122 |
+
|
| 123 |
response = generate_xeno_response(
|
| 124 |
+
self.context, self.question, self.chat_history, timer=mock_timer
|
|
|
|
|
|
|
|
|
|
| 125 |
)
|
| 126 |
+
|
| 127 |
# Verify timer was used
|
| 128 |
mock_timer.time_step.assert_called_once_with("llm_generation")
|
| 129 |
+
|
| 130 |
# Verify response
|
| 131 |
self.assertEqual(response, "Test response")
|
| 132 |
+
|
| 133 |
+
@patch("src.response_generator.genai_client")
|
| 134 |
+
def test_response_text_stripping(self, mock_genai_client):
|
| 135 |
"""Test that response text is stripped of whitespace"""
|
| 136 |
+
mock_genai_client.models.generate_content.return_value.text = "Test response with spaces"
|
| 137 |
+
|
| 138 |
+
response = _generate_response_impl(self.context, self.question, [])
|
| 139 |
+
|
| 140 |
+
# Response should be returned as-is from mock
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
self.assertEqual(response, "Test response with spaces")
|
| 142 |
+
|
| 143 |
+
@patch("src.response_generator.genai_client")
|
| 144 |
+
def test_system_prompt_inclusion(self, mock_genai_client):
|
| 145 |
"""Test that system prompt is included in generated prompt"""
|
| 146 |
+
mock_genai_client.models.generate_content.return_value.text = "Test"
|
| 147 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
_generate_response_impl(self.context, self.question, [])
|
| 149 |
+
|
| 150 |
# Get the prompt
|
| 151 |
+
call_kwargs = mock_genai_client.models.generate_content.call_args[1]
|
| 152 |
+
prompt = call_kwargs["contents"]
|
| 153 |
+
|
| 154 |
# Should contain system prompt text
|
| 155 |
self.assertIn("XENO Support Assistant", prompt)
|
| 156 |
|
| 157 |
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
unittest.main()
|
tests/test_utils.py
CHANGED
|
@@ -2,107 +2,109 @@
|
|
| 2 |
Unit tests for utils module
|
| 3 |
Tests the PipelineTimer class
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
import time
|
|
|
|
|
|
|
| 7 |
from src.utils import PipelineTimer
|
| 8 |
|
| 9 |
|
| 10 |
class TestPipelineTimer(unittest.TestCase):
|
| 11 |
"""Test cases for PipelineTimer class"""
|
| 12 |
-
|
| 13 |
def setUp(self):
|
| 14 |
"""Set up test fixtures"""
|
| 15 |
self.timer = PipelineTimer()
|
| 16 |
-
|
| 17 |
def test_initialization(self):
|
| 18 |
"""Test timer initialization"""
|
| 19 |
self.assertIsNotNone(self.timer.start_time)
|
| 20 |
self.assertEqual(self.timer.step_times, {})
|
| 21 |
self.assertIsNone(self.timer.step_start)
|
| 22 |
self.assertIsNone(self.timer.current_step)
|
| 23 |
-
|
| 24 |
def test_reset(self):
|
| 25 |
"""Test timer reset functionality"""
|
| 26 |
# Add some data
|
| 27 |
-
self.timer.step_times = {
|
| 28 |
-
self.timer.current_step =
|
| 29 |
-
|
| 30 |
# Reset
|
| 31 |
self.timer.reset()
|
| 32 |
-
|
| 33 |
# Verify reset
|
| 34 |
self.assertEqual(self.timer.step_times, {})
|
| 35 |
self.assertIsNone(self.timer.current_step)
|
| 36 |
-
|
| 37 |
def test_time_step_context_manager(self):
|
| 38 |
"""Test timing a step using context manager"""
|
| 39 |
-
with self.timer.time_step(
|
| 40 |
time.sleep(0.1) # Sleep for 100ms
|
| 41 |
-
|
| 42 |
# Check that step was timed
|
| 43 |
-
self.assertIn(
|
| 44 |
# Should be approximately 100ms (allowing some variance)
|
| 45 |
-
self.assertGreater(self.timer.step_times[
|
| 46 |
-
self.assertLess(self.timer.step_times[
|
| 47 |
-
|
| 48 |
def test_multiple_steps(self):
|
| 49 |
"""Test timing multiple steps"""
|
| 50 |
-
with self.timer.time_step(
|
| 51 |
time.sleep(0.05)
|
| 52 |
-
|
| 53 |
-
with self.timer.time_step(
|
| 54 |
time.sleep(0.05)
|
| 55 |
-
|
| 56 |
# Both steps should be recorded
|
| 57 |
-
self.assertIn(
|
| 58 |
-
self.assertIn(
|
| 59 |
self.assertEqual(len(self.timer.step_times), 2)
|
| 60 |
-
|
| 61 |
def test_get_total_time(self):
|
| 62 |
"""Test getting total elapsed time"""
|
| 63 |
time.sleep(0.1)
|
| 64 |
total_time = self.timer.get_total_time()
|
| 65 |
-
|
| 66 |
# Should be at least 100ms
|
| 67 |
self.assertGreater(total_time, 90)
|
| 68 |
-
|
| 69 |
def test_get_timing_summary(self):
|
| 70 |
"""Test getting timing summary"""
|
| 71 |
-
with self.timer.time_step(
|
| 72 |
time.sleep(0.05)
|
| 73 |
-
|
| 74 |
summary = self.timer.get_timing_summary()
|
| 75 |
-
|
| 76 |
# Check summary structure
|
| 77 |
-
self.assertIn(
|
| 78 |
-
self.assertIn(
|
| 79 |
-
self.assertIn(
|
| 80 |
-
self.assertIn(
|
| 81 |
-
|
| 82 |
def test_current_step_tracking(self):
|
| 83 |
"""Test that current_step is tracked correctly"""
|
| 84 |
self.assertIsNone(self.timer.current_step)
|
| 85 |
-
|
| 86 |
-
with self.timer.time_step(
|
| 87 |
# During execution, current_step should be set
|
| 88 |
-
self.assertEqual(self.timer.current_step,
|
| 89 |
-
|
| 90 |
# After execution, current_step should be None
|
| 91 |
self.assertIsNone(self.timer.current_step)
|
| 92 |
-
|
| 93 |
def test_exception_handling_in_timer(self):
|
| 94 |
"""Test that timer handles exceptions properly"""
|
| 95 |
try:
|
| 96 |
-
with self.timer.time_step(
|
| 97 |
raise ValueError("Test error")
|
| 98 |
except ValueError:
|
| 99 |
pass
|
| 100 |
-
|
| 101 |
# Step should still be recorded even if exception occurred
|
| 102 |
-
self.assertIn(
|
| 103 |
# current_step should be None after context manager exits
|
| 104 |
self.assertIsNone(self.timer.current_step)
|
| 105 |
|
| 106 |
|
| 107 |
-
if __name__ ==
|
| 108 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for utils module
|
| 3 |
Tests the PipelineTimer class
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import time
|
| 7 |
+
import unittest
|
| 8 |
+
|
| 9 |
from src.utils import PipelineTimer
|
| 10 |
|
| 11 |
|
| 12 |
class TestPipelineTimer(unittest.TestCase):
|
| 13 |
"""Test cases for PipelineTimer class"""
|
| 14 |
+
|
| 15 |
def setUp(self):
|
| 16 |
"""Set up test fixtures"""
|
| 17 |
self.timer = PipelineTimer()
|
| 18 |
+
|
| 19 |
def test_initialization(self):
|
| 20 |
"""Test timer initialization"""
|
| 21 |
self.assertIsNotNone(self.timer.start_time)
|
| 22 |
self.assertEqual(self.timer.step_times, {})
|
| 23 |
self.assertIsNone(self.timer.step_start)
|
| 24 |
self.assertIsNone(self.timer.current_step)
|
| 25 |
+
|
| 26 |
def test_reset(self):
|
| 27 |
"""Test timer reset functionality"""
|
| 28 |
# Add some data
|
| 29 |
+
self.timer.step_times = {"test": 100}
|
| 30 |
+
self.timer.current_step = "test"
|
| 31 |
+
|
| 32 |
# Reset
|
| 33 |
self.timer.reset()
|
| 34 |
+
|
| 35 |
# Verify reset
|
| 36 |
self.assertEqual(self.timer.step_times, {})
|
| 37 |
self.assertIsNone(self.timer.current_step)
|
| 38 |
+
|
| 39 |
def test_time_step_context_manager(self):
|
| 40 |
"""Test timing a step using context manager"""
|
| 41 |
+
with self.timer.time_step("test_step"):
|
| 42 |
time.sleep(0.1) # Sleep for 100ms
|
| 43 |
+
|
| 44 |
# Check that step was timed
|
| 45 |
+
self.assertIn("test_step", self.timer.step_times)
|
| 46 |
# Should be approximately 100ms (allowing some variance)
|
| 47 |
+
self.assertGreater(self.timer.step_times["test_step"], 90)
|
| 48 |
+
self.assertLess(self.timer.step_times["test_step"], 150)
|
| 49 |
+
|
| 50 |
def test_multiple_steps(self):
|
| 51 |
"""Test timing multiple steps"""
|
| 52 |
+
with self.timer.time_step("step1"):
|
| 53 |
time.sleep(0.05)
|
| 54 |
+
|
| 55 |
+
with self.timer.time_step("step2"):
|
| 56 |
time.sleep(0.05)
|
| 57 |
+
|
| 58 |
# Both steps should be recorded
|
| 59 |
+
self.assertIn("step1", self.timer.step_times)
|
| 60 |
+
self.assertIn("step2", self.timer.step_times)
|
| 61 |
self.assertEqual(len(self.timer.step_times), 2)
|
| 62 |
+
|
| 63 |
def test_get_total_time(self):
|
| 64 |
"""Test getting total elapsed time"""
|
| 65 |
time.sleep(0.1)
|
| 66 |
total_time = self.timer.get_total_time()
|
| 67 |
+
|
| 68 |
# Should be at least 100ms
|
| 69 |
self.assertGreater(total_time, 90)
|
| 70 |
+
|
| 71 |
def test_get_timing_summary(self):
|
| 72 |
"""Test getting timing summary"""
|
| 73 |
+
with self.timer.time_step("step1"):
|
| 74 |
time.sleep(0.05)
|
| 75 |
+
|
| 76 |
summary = self.timer.get_timing_summary()
|
| 77 |
+
|
| 78 |
# Check summary structure
|
| 79 |
+
self.assertIn("total_time_ms", summary)
|
| 80 |
+
self.assertIn("step_times", summary)
|
| 81 |
+
self.assertIn("timestamp", summary)
|
| 82 |
+
self.assertIn("step1", summary["step_times"])
|
| 83 |
+
|
| 84 |
def test_current_step_tracking(self):
|
| 85 |
"""Test that current_step is tracked correctly"""
|
| 86 |
self.assertIsNone(self.timer.current_step)
|
| 87 |
+
|
| 88 |
+
with self.timer.time_step("test_step"):
|
| 89 |
# During execution, current_step should be set
|
| 90 |
+
self.assertEqual(self.timer.current_step, "test_step")
|
| 91 |
+
|
| 92 |
# After execution, current_step should be None
|
| 93 |
self.assertIsNone(self.timer.current_step)
|
| 94 |
+
|
| 95 |
def test_exception_handling_in_timer(self):
|
| 96 |
"""Test that timer handles exceptions properly"""
|
| 97 |
try:
|
| 98 |
+
with self.timer.time_step("error_step"):
|
| 99 |
raise ValueError("Test error")
|
| 100 |
except ValueError:
|
| 101 |
pass
|
| 102 |
+
|
| 103 |
# Step should still be recorded even if exception occurred
|
| 104 |
+
self.assertIn("error_step", self.timer.step_times)
|
| 105 |
# current_step should be None after context manager exits
|
| 106 |
self.assertIsNone(self.timer.current_step)
|
| 107 |
|
| 108 |
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
unittest.main()
|
tests/test_vector_store.py
CHANGED
|
@@ -2,139 +2,154 @@
|
|
| 2 |
Unit tests for vector_store module
|
| 3 |
Tests ChromaDB vector store operations
|
| 4 |
"""
|
|
|
|
| 5 |
import unittest
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
from
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
process_context,
|
| 13 |
-
_generate_embeddings_impl,
|
| 14 |
-
_calculate_similarity_impl,
|
| 15 |
-
_process_context_impl
|
| 16 |
-
)
|
| 17 |
|
| 18 |
|
| 19 |
class TestVectorStore(unittest.TestCase):
|
| 20 |
"""Test cases for vector_store module"""
|
| 21 |
-
|
| 22 |
def setUp(self):
|
| 23 |
"""Set up test fixtures"""
|
| 24 |
# Mock document
|
| 25 |
self.mock_doc = Mock()
|
| 26 |
self.mock_doc.page_content = "Test document content"
|
| 27 |
self.mock_doc.metadata = {
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
}
|
| 33 |
-
|
| 34 |
self.mock_documents = [self.mock_doc]
|
| 35 |
-
|
| 36 |
-
@patch(
|
| 37 |
-
def test_generate_embeddings_impl(self,
|
| 38 |
"""Test internal embedding generation implementation"""
|
| 39 |
-
# Mock embeddings
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
]
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
query = "Test query"
|
| 46 |
query_emb, doc_embs = _generate_embeddings_impl(query, self.mock_documents)
|
| 47 |
-
|
| 48 |
# Verify embed_content was called correctly
|
| 49 |
-
self.assertEqual(
|
| 50 |
-
|
| 51 |
-
# Check query embedding call
|
| 52 |
-
first_call = mock_embed.call_args_list[0]
|
| 53 |
-
self.assertEqual(first_call[1]['content'], query)
|
| 54 |
-
self.assertEqual(first_call[1]['task_type'], 'retrieval_query')
|
| 55 |
-
|
| 56 |
-
# Check doc embedding call
|
| 57 |
-
second_call = mock_embed.call_args_list[1]
|
| 58 |
-
self.assertEqual(second_call[1]['content'], self.mock_doc.page_content)
|
| 59 |
-
self.assertEqual(second_call[1]['task_type'], 'retrieval_document')
|
| 60 |
-
|
| 61 |
# Verify embeddings
|
| 62 |
self.assertEqual(query_emb, [0.1, 0.2, 0.3])
|
| 63 |
self.assertEqual(len(doc_embs), 1)
|
| 64 |
self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])
|
| 65 |
-
|
| 66 |
-
@patch(
|
| 67 |
-
def test_generate_embeddings_with_timer(self,
|
| 68 |
"""Test embedding generation with timer"""
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
mock_timer = Mock()
|
| 75 |
mock_timer.time_step = MagicMock()
|
| 76 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 77 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 78 |
-
|
| 79 |
generate_embeddings("Test", self.mock_documents, timer=mock_timer)
|
| 80 |
-
|
| 81 |
# Verify timer was used
|
| 82 |
mock_timer.time_step.assert_called_once_with("embedding_generation")
|
| 83 |
-
|
| 84 |
-
@patch(
|
| 85 |
-
def test_generate_embeddings_multiple_docs(self,
|
| 86 |
"""Test embedding generation with multiple documents"""
|
| 87 |
# Create multiple mock documents
|
| 88 |
mock_doc2 = Mock()
|
| 89 |
mock_doc2.page_content = "Second document"
|
| 90 |
docs = [self.mock_doc, mock_doc2]
|
| 91 |
-
|
| 92 |
# Mock embeddings
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
query_emb, doc_embs = _generate_embeddings_impl("Test", docs)
|
| 100 |
-
|
| 101 |
# Should have 2 doc embeddings
|
| 102 |
self.assertEqual(len(doc_embs), 2)
|
| 103 |
-
self.assertEqual(
|
| 104 |
-
|
| 105 |
def test_calculate_similarity_impl(self):
|
| 106 |
"""Test internal similarity calculation implementation"""
|
| 107 |
query_embedding = [1.0, 0.0, 0.0]
|
| 108 |
doc_embeddings = [
|
| 109 |
[1.0, 0.0, 0.0], # Same as query - score should be ~1.0
|
| 110 |
[0.0, 1.0, 0.0], # Orthogonal - score should be ~0.0
|
| 111 |
-
[0.5, 0.5, 0.0]
|
| 112 |
]
|
| 113 |
-
|
| 114 |
scores = _calculate_similarity_impl(query_embedding, doc_embeddings)
|
| 115 |
-
|
| 116 |
# Check scores
|
| 117 |
self.assertEqual(len(scores), 3)
|
| 118 |
self.assertAlmostEqual(scores[0], 1.0, places=5)
|
| 119 |
self.assertAlmostEqual(scores[1], 0.0, places=5)
|
| 120 |
self.assertGreater(scores[2], 0.0)
|
| 121 |
self.assertLess(scores[2], 1.0)
|
| 122 |
-
|
| 123 |
def test_calculate_similarity_with_timer(self):
|
| 124 |
"""Test similarity calculation with timer"""
|
| 125 |
mock_timer = Mock()
|
| 126 |
mock_timer.time_step = MagicMock()
|
| 127 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 128 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 129 |
-
|
| 130 |
query_emb = [1.0, 0.0, 0.0]
|
| 131 |
doc_embs = [[1.0, 0.0, 0.0]]
|
| 132 |
-
|
| 133 |
calculate_similarity(query_emb, doc_embs, timer=mock_timer)
|
| 134 |
-
|
| 135 |
# Verify timer was used
|
| 136 |
mock_timer.time_step.assert_called_once_with("similarity_calculation")
|
| 137 |
-
|
| 138 |
def test_process_context_impl(self):
|
| 139 |
"""Test internal context processing implementation"""
|
| 140 |
# Create mock results with metadata
|
|
@@ -142,48 +157,48 @@ class TestVectorStore(unittest.TestCase):
|
|
| 142 |
for i in range(3):
|
| 143 |
mock_result = Mock()
|
| 144 |
mock_result.metadata = {
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
}
|
| 149 |
results.append(mock_result)
|
| 150 |
-
|
| 151 |
# Cosine scores (sorted: 0.9, 0.7, 0.5)
|
| 152 |
cosine_scores = [0.7, 0.5, 0.9]
|
| 153 |
-
|
| 154 |
context, source_ids, knowledge_pairs = _process_context_impl(
|
| 155 |
results, cosine_scores, max_results=2
|
| 156 |
)
|
| 157 |
-
|
| 158 |
# Should return top 2 results
|
| 159 |
self.assertEqual(len(source_ids), 2)
|
| 160 |
self.assertEqual(len(knowledge_pairs), 2)
|
| 161 |
-
|
| 162 |
# Check that highest score (0.9, index 2) is first
|
| 163 |
-
self.assertEqual(source_ids[0],
|
| 164 |
-
self.assertEqual(knowledge_pairs[0][0],
|
| 165 |
-
|
| 166 |
# Check formatted context
|
| 167 |
self.assertIn("Knowledge Entry 1:", context)
|
| 168 |
self.assertIn("Knowledge Entry 2:", context)
|
| 169 |
self.assertIn("Q: Question 3?", context)
|
| 170 |
self.assertIn("A: Answer 3.", context)
|
| 171 |
-
|
| 172 |
def test_process_context_with_timer(self):
|
| 173 |
"""Test context processing with timer"""
|
| 174 |
mock_result = Mock()
|
| 175 |
-
mock_result.metadata = {
|
| 176 |
-
|
| 177 |
mock_timer = Mock()
|
| 178 |
mock_timer.time_step = MagicMock()
|
| 179 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 180 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 181 |
-
|
| 182 |
process_context([mock_result], [0.9], timer=mock_timer)
|
| 183 |
-
|
| 184 |
# Verify timer was used
|
| 185 |
mock_timer.time_step.assert_called_once_with("context_processing")
|
| 186 |
-
|
| 187 |
def test_process_context_max_results(self):
|
| 188 |
"""Test that max_results parameter limits output"""
|
| 189 |
# Create 5 mock results
|
|
@@ -191,53 +206,157 @@ class TestVectorStore(unittest.TestCase):
|
|
| 191 |
for i in range(5):
|
| 192 |
mock_result = Mock()
|
| 193 |
mock_result.metadata = {
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
}
|
| 198 |
results.append(mock_result)
|
| 199 |
-
|
| 200 |
scores = [0.9, 0.8, 0.7, 0.6, 0.5]
|
| 201 |
-
|
| 202 |
# Request only 3 results
|
| 203 |
context, source_ids, knowledge_pairs = _process_context_impl(
|
| 204 |
results, scores, max_results=3
|
| 205 |
)
|
| 206 |
-
|
| 207 |
# Should only return 3
|
| 208 |
self.assertEqual(len(source_ids), 3)
|
| 209 |
self.assertEqual(len(knowledge_pairs), 3)
|
| 210 |
-
|
| 211 |
def test_process_context_formatting(self):
|
| 212 |
"""Test context formatting details"""
|
| 213 |
mock_result = Mock()
|
| 214 |
mock_result.metadata = {
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
}
|
| 219 |
-
|
| 220 |
context, _, _ = _process_context_impl([mock_result], [0.9], max_results=1)
|
| 221 |
-
|
| 222 |
# Check formatting
|
| 223 |
self.assertIn("Knowledge Entry 1:", context)
|
| 224 |
self.assertIn("Q: Test question?", context)
|
| 225 |
self.assertIn("A: Test answer.", context)
|
| 226 |
self.assertIn("-" * 40, context)
|
| 227 |
-
|
| 228 |
def test_process_context_missing_metadata(self):
|
| 229 |
"""Test context processing with missing metadata fields"""
|
| 230 |
mock_result = Mock()
|
| 231 |
mock_result.metadata = {} # No metadata
|
| 232 |
-
|
| 233 |
context, source_ids, knowledge_pairs = _process_context_impl(
|
| 234 |
[mock_result], [0.9], max_results=1
|
| 235 |
)
|
| 236 |
-
|
| 237 |
# Should handle missing fields with N/A
|
| 238 |
self.assertIn("N/A", context)
|
| 239 |
self.assertEqual(source_ids[0], "N/A")
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
-
if __name__ ==
|
| 243 |
unittest.main()
|
|
|
|
| 2 |
Unit tests for vector_store module
|
| 3 |
Tests ChromaDB vector store operations
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import unittest
|
| 7 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 8 |
+
|
| 9 |
+
from src.vector_store import (_calculate_similarity_impl,
|
| 10 |
+
_generate_embeddings_impl, _process_context_impl,
|
| 11 |
+
calculate_similarity, generate_embeddings,
|
| 12 |
+
process_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class TestVectorStore(unittest.TestCase):
|
| 16 |
"""Test cases for vector_store module"""
|
| 17 |
+
|
| 18 |
def setUp(self):
|
| 19 |
"""Set up test fixtures"""
|
| 20 |
# Mock document
|
| 21 |
self.mock_doc = Mock()
|
| 22 |
self.mock_doc.page_content = "Test document content"
|
| 23 |
self.mock_doc.metadata = {
|
| 24 |
+
"id": "KB001",
|
| 25 |
+
"question": "Test question?",
|
| 26 |
+
"content": "Test answer.",
|
| 27 |
+
"section": "Test",
|
| 28 |
}
|
| 29 |
+
|
| 30 |
self.mock_documents = [self.mock_doc]
|
| 31 |
+
|
| 32 |
+
@patch("src.vector_store.genai_client")
|
| 33 |
+
def test_generate_embeddings_impl(self, mock_genai_client):
|
| 34 |
"""Test internal embedding generation implementation"""
|
| 35 |
+
# Mock embeddings for query and document
|
| 36 |
+
mock_query_embedding = Mock()
|
| 37 |
+
mock_query_embedding.values = [0.1, 0.2, 0.3]
|
| 38 |
+
mock_doc_embedding = Mock()
|
| 39 |
+
mock_doc_embedding.values = [0.2, 0.3, 0.4]
|
| 40 |
+
|
| 41 |
+
# Setup side effect for multiple calls
|
| 42 |
+
call_count = [0]
|
| 43 |
+
def embed_side_effect(*args, **kwargs):
|
| 44 |
+
call_count[0] += 1
|
| 45 |
+
mock_response = Mock()
|
| 46 |
+
if call_count[0] == 1:
|
| 47 |
+
mock_response.embeddings = [mock_query_embedding]
|
| 48 |
+
else:
|
| 49 |
+
mock_response.embeddings = [mock_doc_embedding]
|
| 50 |
+
return mock_response
|
| 51 |
+
|
| 52 |
+
mock_genai_client.models.embed_content.side_effect = embed_side_effect
|
| 53 |
+
|
| 54 |
query = "Test query"
|
| 55 |
query_emb, doc_embs = _generate_embeddings_impl(query, self.mock_documents)
|
| 56 |
+
|
| 57 |
# Verify embed_content was called correctly
|
| 58 |
+
self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
|
| 59 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# Verify embeddings
|
| 61 |
self.assertEqual(query_emb, [0.1, 0.2, 0.3])
|
| 62 |
self.assertEqual(len(doc_embs), 1)
|
| 63 |
self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])
|
| 64 |
+
|
| 65 |
+
@patch("src.vector_store.genai_client")
|
| 66 |
+
def test_generate_embeddings_with_timer(self, mock_genai_client):
|
| 67 |
"""Test embedding generation with timer"""
|
| 68 |
+
# Mock embeddings
|
| 69 |
+
mock_embedding = Mock()
|
| 70 |
+
mock_embedding.values = [0.1, 0.2, 0.3]
|
| 71 |
+
mock_response = Mock()
|
| 72 |
+
mock_response.embeddings = [mock_embedding]
|
| 73 |
+
mock_genai_client.models.embed_content.return_value = mock_response
|
| 74 |
+
|
| 75 |
mock_timer = Mock()
|
| 76 |
mock_timer.time_step = MagicMock()
|
| 77 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 78 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 79 |
+
|
| 80 |
generate_embeddings("Test", self.mock_documents, timer=mock_timer)
|
| 81 |
+
|
| 82 |
# Verify timer was used
|
| 83 |
mock_timer.time_step.assert_called_once_with("embedding_generation")
|
| 84 |
+
|
| 85 |
+
@patch("src.vector_store.genai_client")
|
| 86 |
+
def test_generate_embeddings_multiple_docs(self, mock_genai_client):
|
| 87 |
"""Test embedding generation with multiple documents"""
|
| 88 |
# Create multiple mock documents
|
| 89 |
mock_doc2 = Mock()
|
| 90 |
mock_doc2.page_content = "Second document"
|
| 91 |
docs = [self.mock_doc, mock_doc2]
|
| 92 |
+
|
| 93 |
# Mock embeddings
|
| 94 |
+
mock_query_emb = Mock()
|
| 95 |
+
mock_query_emb.values = [0.1, 0.2, 0.3]
|
| 96 |
+
mock_doc1_emb = Mock()
|
| 97 |
+
mock_doc1_emb.values = [0.2, 0.3, 0.4]
|
| 98 |
+
mock_doc2_emb = Mock()
|
| 99 |
+
mock_doc2_emb.values = [0.3, 0.4, 0.5]
|
| 100 |
+
|
| 101 |
+
# First call for query, second call for both docs
|
| 102 |
+
call_count = [0]
|
| 103 |
+
def embed_side_effect(*args, **kwargs):
|
| 104 |
+
call_count[0] += 1
|
| 105 |
+
mock_response = Mock()
|
| 106 |
+
if call_count[0] == 1:
|
| 107 |
+
mock_response.embeddings = [mock_query_emb]
|
| 108 |
+
else:
|
| 109 |
+
mock_response.embeddings = [mock_doc1_emb, mock_doc2_emb]
|
| 110 |
+
return mock_response
|
| 111 |
+
|
| 112 |
+
mock_genai_client.models.embed_content.side_effect = embed_side_effect
|
| 113 |
+
|
| 114 |
query_emb, doc_embs = _generate_embeddings_impl("Test", docs)
|
| 115 |
+
|
| 116 |
# Should have 2 doc embeddings
|
| 117 |
self.assertEqual(len(doc_embs), 2)
|
| 118 |
+
self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
|
| 119 |
+
|
| 120 |
def test_calculate_similarity_impl(self):
|
| 121 |
"""Test internal similarity calculation implementation"""
|
| 122 |
query_embedding = [1.0, 0.0, 0.0]
|
| 123 |
doc_embeddings = [
|
| 124 |
[1.0, 0.0, 0.0], # Same as query - score should be ~1.0
|
| 125 |
[0.0, 1.0, 0.0], # Orthogonal - score should be ~0.0
|
| 126 |
+
[0.5, 0.5, 0.0], # Partial similarity
|
| 127 |
]
|
| 128 |
+
|
| 129 |
scores = _calculate_similarity_impl(query_embedding, doc_embeddings)
|
| 130 |
+
|
| 131 |
# Check scores
|
| 132 |
self.assertEqual(len(scores), 3)
|
| 133 |
self.assertAlmostEqual(scores[0], 1.0, places=5)
|
| 134 |
self.assertAlmostEqual(scores[1], 0.0, places=5)
|
| 135 |
self.assertGreater(scores[2], 0.0)
|
| 136 |
self.assertLess(scores[2], 1.0)
|
| 137 |
+
|
| 138 |
def test_calculate_similarity_with_timer(self):
|
| 139 |
"""Test similarity calculation with timer"""
|
| 140 |
mock_timer = Mock()
|
| 141 |
mock_timer.time_step = MagicMock()
|
| 142 |
mock_timer.time_step.return_value.__enter__ = Mock()
|
| 143 |
mock_timer.time_step.return_value.__exit__ = Mock()
|
| 144 |
+
|
| 145 |
query_emb = [1.0, 0.0, 0.0]
|
| 146 |
doc_embs = [[1.0, 0.0, 0.0]]
|
| 147 |
+
|
| 148 |
calculate_similarity(query_emb, doc_embs, timer=mock_timer)
|
| 149 |
+
|
| 150 |
# Verify timer was used
|
| 151 |
mock_timer.time_step.assert_called_once_with("similarity_calculation")
|
| 152 |
+
|
| 153 |
def test_process_context_impl(self):
|
| 154 |
"""Test internal context processing implementation"""
|
| 155 |
# Create mock results with metadata
|
|
|
|
| 157 |
for i in range(3):
|
| 158 |
mock_result = Mock()
|
| 159 |
mock_result.metadata = {
|
| 160 |
+
"id": f"KB00{i+1}",
|
| 161 |
+
"question": f"Question {i+1}?",
|
| 162 |
+
"content": f"Answer {i+1}.",
|
| 163 |
}
|
| 164 |
results.append(mock_result)
|
| 165 |
+
|
| 166 |
# Cosine scores (sorted: 0.9, 0.7, 0.5)
|
| 167 |
cosine_scores = [0.7, 0.5, 0.9]
|
| 168 |
+
|
| 169 |
context, source_ids, knowledge_pairs = _process_context_impl(
|
| 170 |
results, cosine_scores, max_results=2
|
| 171 |
)
|
| 172 |
+
|
| 173 |
# Should return top 2 results
|
| 174 |
self.assertEqual(len(source_ids), 2)
|
| 175 |
self.assertEqual(len(knowledge_pairs), 2)
|
| 176 |
+
|
| 177 |
# Check that highest score (0.9, index 2) is first
|
| 178 |
+
self.assertEqual(source_ids[0], "KB003")
|
| 179 |
+
self.assertEqual(knowledge_pairs[0][0], "Question 3?")
|
| 180 |
+
|
| 181 |
# Check formatted context
|
| 182 |
self.assertIn("Knowledge Entry 1:", context)
|
| 183 |
self.assertIn("Knowledge Entry 2:", context)
|
| 184 |
self.assertIn("Q: Question 3?", context)
|
| 185 |
self.assertIn("A: Answer 3.", context)
|
| 186 |
+
|
| 187 |
def test_process_context_with_timer(self):
    """Test context processing with timer.

    A mocked timer is passed to ``process_context`` and the test checks
    that ``time_step`` was requested exactly once with the
    ``"context_processing"`` step name.
    """
    mock_result = Mock()
    mock_result.metadata = {"id": "KB001", "question": "Q?", "content": "A."}

    mock_timer = Mock()
    mock_timer.time_step = MagicMock()
    mock_timer.time_step.return_value.__enter__ = Mock()
    # __exit__ has to return a falsy value. A bare Mock() returns another
    # (truthy) Mock, which makes the ``with`` statement suppress exceptions
    # raised inside the timed block and would hide real failures.
    mock_timer.time_step.return_value.__exit__ = Mock(return_value=False)

    process_context([mock_result], [0.9], timer=mock_timer)

    # Verify timer was used
    mock_timer.time_step.assert_called_once_with("context_processing")
|
| 201 |
+
|
| 202 |
def test_process_context_max_results(self):
    """Test that max_results parameter limits output.

    Five mocked results are supplied but only three are requested, so
    both returned collections are expected to have length three.
    """
    # Create 5 mock results. The list must exist before the loop appends
    # to it, otherwise the test dies with a NameError.
    results = []
    for i in range(5):
        mock_result = Mock()
        mock_result.metadata = {
            "id": f"KB00{i}",
            "question": f"Q{i}?",
            "content": f"A{i}.",
        }
        results.append(mock_result)

    scores = [0.9, 0.8, 0.7, 0.6, 0.5]

    # Request only 3 results
    context, source_ids, knowledge_pairs = _process_context_impl(
        results, scores, max_results=3
    )

    # Should only return 3
    self.assertEqual(len(source_ids), 3)
    self.assertEqual(len(knowledge_pairs), 3)
|
| 225 |
+
|
| 226 |
def test_process_context_formatting(self):
    """Check the text layout produced for a single knowledge entry."""
    entry = Mock()
    entry.metadata = {
        "id": "KB001",
        "question": "Test question?",
        "content": "Test answer.",
    }

    formatted, _, _ = _process_context_impl([entry], [0.9], max_results=1)

    # Every expected fragment must appear in the formatted context
    expected_fragments = (
        "Knowledge Entry 1:",
        "Q: Test question?",
        "A: Test answer.",
        "-" * 40,
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, formatted)
|
| 242 |
+
|
| 243 |
def test_process_context_missing_metadata(self):
    """An entry with an empty metadata dict is expected to yield "N/A"."""
    bare_entry = Mock()
    bare_entry.metadata = {}  # none of the usual fields are present

    formatted, ids, _pairs = _process_context_impl(
        [bare_entry], [0.9], max_results=1
    )

    # Missing fields are reported as the N/A placeholder
    self.assertIn("N/A", formatted)
    self.assertEqual(ids[0], "N/A")
|
| 255 |
|
| 256 |
+
@patch("src.vector_store.get_knowledge_base_data")
@patch("src.vector_store.chromadb.PersistentClient")
@patch("src.vector_store.Chroma")
def test_initialize_vector_store_new_collection(
    self, mock_chroma_class, mock_client_class, mock_get_kb
):
    """Initialisation creates and fills a collection when none exists."""
    # Knowledge base stub: (documents, metadatas, ids)
    kb_documents = ["doc1", "doc2"]
    kb_metadatas = [{"id": "1"}, {"id": "2"}]
    kb_ids = ["id1", "id2"]
    mock_get_kb.return_value = (kb_documents, kb_metadatas, kb_ids)

    # Chroma wrapper stub that hands out a known retriever
    fake_retriever = Mock()
    fake_store = Mock()
    fake_store.as_retriever.return_value = fake_retriever
    mock_chroma_class.return_value = fake_store

    # Client stub: looking the collection up fails, creating it succeeds
    created_collection = Mock()
    fake_client = Mock()
    fake_client.get_collection.side_effect = Exception("Collection not found")
    fake_client.create_collection.return_value = created_collection
    mock_client_class.return_value = fake_client

    from src.vector_store import initialize_vector_store

    _collection, vector_store, retriever = initialize_vector_store()

    # A new collection was created and populated exactly once
    fake_client.create_collection.assert_called_once()
    created_collection.add.assert_called_once()

    # The mocked store and retriever are what the caller gets back
    self.assertEqual(vector_store, fake_store)
    self.assertEqual(retriever, fake_retriever)
|
| 299 |
+
|
| 300 |
+
@patch("src.vector_store.get_knowledge_base_data")
@patch("src.vector_store.chromadb.PersistentClient")
@patch("src.vector_store.Chroma")
def test_initialize_vector_store_existing_collection(
    self, mock_chroma_class, mock_client_class, mock_get_kb
):
    """Initialisation reuses a collection the client already has."""
    # Knowledge base stub: (documents, metadatas, ids)
    kb_documents = ["doc1", "doc2"]
    kb_metadatas = [{"id": "1"}, {"id": "2"}]
    kb_ids = ["id1", "id2"]
    mock_get_kb.return_value = (kb_documents, kb_metadatas, kb_ids)

    # Chroma wrapper stub that hands out a known retriever
    fake_retriever = Mock()
    fake_store = Mock()
    fake_store.as_retriever.return_value = fake_retriever
    mock_chroma_class.return_value = fake_store

    # Client stub: the collection lookup succeeds
    existing_collection = Mock()
    fake_client = Mock()
    fake_client.get_collection.return_value = existing_collection
    mock_client_class.return_value = fake_client

    from src.vector_store import initialize_vector_store

    collection, vector_store, retriever = initialize_vector_store()

    # The collection was looked up once and never created
    fake_client.get_collection.assert_called_once()
    fake_client.create_collection.assert_not_called()

    # The existing collection and the mocked store/retriever come back
    self.assertEqual(collection, existing_collection)
    self.assertEqual(vector_store, fake_store)
    self.assertEqual(retriever, fake_retriever)
|
| 341 |
+
|
| 342 |
+
@patch("src.vector_store.get_knowledge_base_data")
@patch("src.vector_store.chromadb.PersistentClient")
def test_initialize_vector_store_failure(self, mock_client_class, mock_get_kb):
    """A failure while building the client reaches the caller."""
    # Knowledge base stub with a single entry
    mock_get_kb.return_value = (["doc1"], [{"id": "1"}], ["id1"])

    # Constructing the persistent client raises
    mock_client_class.side_effect = Exception("Database connection failed")

    from src.vector_store import initialize_vector_store

    with self.assertRaises(Exception) as caught:
        initialize_vector_store()

    # The original failure text is still visible on the raised exception
    self.assertIn("Database connection failed", str(caught.exception))
|
| 359 |
+
|
| 360 |
|
| 361 |
+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
tox.ini
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tox configuration: pytest environments plus formatting and lint helpers.
[tox]
# Environments run by a bare `tox` invocation.
# NOTE(review): the CI workflow pins Python 3.13 while this list targets
# 3.10/3.11 only — confirm whether a py313 environment should be added.
envlist = py310,py311,format,lint
# The project is not packaged; skip building an sdist.
skipsdist = True

# Default environment: install the project requirements and run pytest,
# forwarding any extra command-line arguments via {posargs}.
[testenv]
deps = -r requirements.txt
commands = pytest {posargs}

# Rewrites files in place: drops unused imports/variables with autoflake,
# then runs black and isort over src, tests and app.py.
[testenv:format]
deps =
    black
    isort
    autoflake
commands =
    autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive src tests app.py
    black src tests app.py
    isort src tests app.py

# Static checks over src and tests (app.py is not included here).
[testenv:lint]
deps =
    flake8
    pylint
commands =
    flake8 src tests
    pylint src tests
|