github-actions commited on
Commit
3cdce90
·
1 Parent(s): 0aa781d

Sync from GitHub

Browse files
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ GOOGLE_SHEETS_CREDENTIALS = 'google_sheets_credentials.json'
2
+ GEMINI_API_KEY = "your_gemini_api_key_here"
.github/workflows/tests.yml CHANGED
@@ -2,9 +2,9 @@ name: Run Tests
2
 
3
  on:
4
  push:
5
- branches: [ main, develop ]
6
  pull_request:
7
- branches: [ main, develop ]
8
 
9
  jobs:
10
  test:
@@ -12,7 +12,7 @@ jobs:
12
 
13
  strategy:
14
  matrix:
15
- python-version: ['3.11']
16
 
17
  steps:
18
  - uses: actions/checkout@v3
 
2
 
3
  on:
4
  push:
5
+ branches: [ main, development ]
6
  pull_request:
7
+ branches: [ main, development ]
8
 
9
  jobs:
10
  test:
 
12
 
13
  strategy:
14
  matrix:
15
+ python-version: ['3.13']
16
 
17
  steps:
18
  - uses: actions/checkout@v3
XENO%20Uganda_KnowlegeBase_V1.json DELETED
The diff for this file is too large to render. See raw diff
 
app.py CHANGED
@@ -2,46 +2,23 @@
2
  XENO Bot - AI-powered customer service assistant
3
  Main application file with Gradio interface
4
  """
5
- import os
6
- import uuid
7
- import gradio as gr
8
- import pandas as pd
9
- import torch
10
- import numpy as np
11
- from sentence_transformers import util
12
- from google import genai
13
- import chromadb
14
- from langchain_chroma import Chroma
15
- import gspread
16
- from google.oauth2.service_account import Credentials
17
- from langgraph.checkpoint.sqlite import SqliteSaver
18
- import sqlite3
19
- import json
20
- from datetime import datetime
21
- import re
22
- from typing import Dict, List, Tuple
23
- import time
24
- from contextlib import contextmanager
25
- import threading # <--- Added for non-blocking feedback logging
26
  import logging
 
27
  import traceback
28
 
29
- # Import custom modules
30
- from src.utils import PipelineTimer
31
- from src.config import SIMILARITY_THRESHOLD, SERVER_NAME, SERVER_PORT
32
- from src.memory import create_session_config, update_memory, retrieve_memory
33
  from src.intent_classifier import IntentClassifier
34
- from src.vector_store import (
35
- initialize_vector_store,
36
- generate_embeddings,
37
- calculate_similarity,
38
- process_context
39
- )
40
- from src.response_generator import generate_xeno_response
41
  from src.logger import log_response, log_timing_data
42
-
43
- # Initialize components
44
- timer = PipelineTimer()
 
 
 
45
 
46
  # === Configuration ===
47
  # Ensure API Key is set
@@ -49,351 +26,61 @@ if "GEMINI_API_KEY" not in os.environ:
49
  print("WARNING: GEMINI_API_KEY environment variable not found.")
50
 
51
  # Initialize the client
52
- genai_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
53
- embedding_model = "models/embedding-001"
54
- llm_model_name = "models/gemma-3-4b-it"
55
- collection_name = "xeno_collection"
56
-
57
- # === Google Sheets Setup ===
58
- def get_google_sheets_credentials():
59
- credentials_json = os.environ.get("GOOGLE_SHEETS_CREDENTIALS")
60
- if not credentials_json:
61
- raise ValueError("GOOGLE_SHEETS_CREDENTIALS environment variable not set.")
62
- credentials_dict = json.loads(credentials_json)
63
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
64
- creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
65
- return creds
66
-
67
- # Authenticate
68
- try:
69
- client_gspread = gspread.authorize(get_google_sheets_credentials())
70
- spreadsheet = client_gspread.open("Response_Log")
71
- response_sheet = spreadsheet.sheet1
72
- except Exception as e:
73
- print(f"Error connecting to Google Sheets: {e}")
74
- # Create dummy objects if connection fails to prevent app crash during dev
75
- class DummySheet:
76
- def append_row(self, *args, **kwargs): pass
77
- def worksheet(self, *args): return self
78
- def add_worksheet(self, *args, **kwargs): return self
79
- spreadsheet = DummySheet()
80
- response_sheet = DummySheet()
81
-
82
- # Setup Timing Sheet
83
- try:
84
- timing_sheet = spreadsheet.worksheet("Timing_Log")
85
- except:
86
- try:
87
- timing_sheet = spreadsheet.add_worksheet(title="Timing_Log", rows="1000", cols="15")
88
- headers = [
89
- "Timestamp", "Session_ID", "Question", "Total_Time_MS",
90
- "Intent_Classification_MS", "Memory_Retrieval_MS", "RAG_Retrieval_MS",
91
- "Embedding_Generation_MS", "Similarity_Calculation_MS", "Context_Processing_MS",
92
- "LLM_Generation_MS", "Memory_Update_MS", "Logging_MS", "Error_Step", "Notes"
93
- ]
94
- timing_sheet.append_row(headers)
95
- except Exception as e:
96
- print(f"Could not create Timing_Log sheet: {e}")
97
- timing_sheet = None
98
-
99
- # === NEW: Setup Feedback Sheet ===
100
- try:
101
- feedback_sheet = spreadsheet.worksheet("Feedback_Log")
102
- except:
103
- try:
104
- feedback_sheet = spreadsheet.add_worksheet(title="Feedback_Log", rows="1000", cols="6")
105
- headers = ["Timestamp", "Session_ID", "User_Message", "Bot_Response", "Rating", "Flag_Reason"]
106
- feedback_sheet.append_row(headers)
107
- except Exception as e:
108
- print(f"Could not create Feedback_Log sheet: {e}")
109
- feedback_sheet = None
110
-
111
- # === Logging Functions ===
112
-
113
- def log_response(question, answer, source_ids, knowledge_pairs, session_id):
114
- """Original response logging function"""
115
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
116
- knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
117
- knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
118
- knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
119
- knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
120
- row = [
121
- timestamp, session_id, question, answer, source_ids,
122
- knowledge_question_1, knowledge_answer_1, knowledge_question_2, knowledge_answer_2
123
- ]
124
- try:
125
- response_sheet.append_row(row)
126
- print(f"Logged response: {question} | Source IDs: {source_ids}")
127
- except Exception as e:
128
- print(f"Failed to log to Google Sheet: {e}")
129
- with open("/tmp/response_log.txt", "a") as f:
130
- f.write(f"{timestamp},{question},{answer},{source_ids}\n")
131
-
132
- def log_timing_data(question, session_id, timing_summary, error_step=None, notes=None):
133
- """Log timing data to the timing sheet"""
134
- if timing_sheet is None: return
135
-
136
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
137
- step_times = timing_summary['step_times']
138
-
139
- row = [
140
- timestamp,
141
- session_id,
142
- question[:100] + "..." if len(question) > 100 else question,
143
- timing_summary['total_time_ms'],
144
- step_times.get('intent_classification', 0),
145
- step_times.get('memory_retrieval', 0),
146
- step_times.get('rag_retrieval', 0),
147
- step_times.get('embedding_generation', 0),
148
- step_times.get('similarity_calculation', 0),
149
- step_times.get('context_processing', 0),
150
- step_times.get('llm_generation', 0),
151
- step_times.get('memory_update', 0),
152
- step_times.get('response_logging', 0),
153
- error_step or "",
154
- notes or ""
155
- ]
156
-
157
- try:
158
- timing_sheet.append_row(row)
159
- print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
160
- except Exception as e:
161
- print(f"Failed to log timing data: {e}")
162
-
163
- # === NEW: Feedback Functions ===
164
-
165
- def _log_feedback_background(row):
166
- """Helper to run network request in background thread"""
167
- try:
168
- if feedback_sheet:
169
- feedback_sheet.append_row(row)
170
- print("Feedback logged successfully.")
171
- else:
172
- print("Feedback sheet not available.")
173
- except Exception as e:
174
- print(f"Failed to log feedback: {e}")
175
-
176
- def submit_feedback(rating, reason, history, session_id):
177
- """
178
- Handles user feedback submission.
179
- rating: 'Positive' or 'Negative'
180
- reason: User provided text
181
- history: Gradio chat history list
182
- """
183
- if not history or len(history) == 0:
184
- return "No conversation to rate yet."
185
-
186
- # Get the last interaction (Gradio history is a list of lists: [[user, bot], ...])
187
- last_interaction = history[-1]
188
-
189
- # Safety check for history format
190
- if isinstance(last_interaction, list) and len(last_interaction) >= 2:
191
- user_msg = last_interaction[0]
192
- bot_msg = last_interaction[1]
193
- else:
194
- return "Error reading conversation history."
195
-
196
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
197
-
198
- # Prepare row data
199
- row = [timestamp, session_id, user_msg, bot_msg, rating, reason]
200
-
201
- # Run in thread to prevent UI blocking
202
- threading.Thread(target=_log_feedback_background, args=(row,)).start()
203
-
204
- return f"Feedback received ({rating}). Thank you!"
205
-
206
- # === LangGraph Memory Setup ===
207
- conn = sqlite3.connect("xeno_memory.db", check_same_thread=False)
208
- memory = SqliteSaver(conn=conn)
209
-
210
- def update_memory(config, user_message, assistant_message):
211
- with timer.time_step("memory_update"):
212
- full_checkpoint = memory.get(config) or {}
213
- messages = full_checkpoint.get("channel_values", {}).get("messages", [])
214
-
215
- messages.append({"role": "user", "content": user_message})
216
- messages.append({"role": "assistant", "content": assistant_message})
217
-
218
- checkpoint_to_save = {
219
- "v": 1,
220
- "id": str(uuid.uuid4()),
221
- "ts": datetime.now().isoformat(),
222
- "channel_values": {"messages": messages},
223
- "channel_versions": {},
224
- "versions_seen": {},
225
- }
226
-
227
- memory.put(config, checkpoint_to_save, {}, {})
228
-
229
- def retrieve_memory(config):
230
- with timer.time_step("memory_retrieval"):
231
- full_checkpoint = memory.get(config) or {}
232
- return full_checkpoint.get("channel_values", {}).get("messages", [])
233
 
234
  # === Intent Classification System ===
235
- class IntentClassifier:
236
- def __init__(self):
237
- self.intent_patterns = {
238
- 'greeting': {
239
- 'patterns': [
240
- r'\b(hi|hello|hey|good morning|good afternoon|good evening|greetings)\b',
241
- r'^(hi|hello|hey)[\s!.]*$',
242
- r'\b(how are you|how do you do)\b'
243
- ],
244
- 'responses': [
245
- "Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
246
- "Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
247
- "Good day! Welcome to XENO Support. How may I assist you today?"
248
- ]
249
- },
250
- 'thanks': {
251
- 'patterns': [
252
- r'\b(thank you|thanks|thank u|thx|appreciate|grateful)\b',
253
- r'^(thanks|thank you)[\s!.]*$',
254
- r'\b(much appreciated|thanks a lot|thank you so much)\b'
255
- ],
256
- 'responses': [
257
- "You're welcome! Is there anything else I can help you with regarding XENO services?",
258
- "Happy to help! Feel free to ask if you have any other questions about XENO.",
259
- "Glad I could assist you! Let me know if you need help with anything else."
260
- ]
261
- },
262
- 'goodbye': {
263
- 'patterns': [
264
- r'\b(bye|goodbye|see you|farewell|take care|have a good day)\b',
265
- r'^(bye|goodbye)[\s!.]*$',
266
- r'\b(talk to you later|see you later|until next time)\b'
267
- ],
268
- 'responses': [
269
- "Goodbye! Thank you for using XENO services. Have a great day!",
270
- "Take care! Feel free to return anytime you need help with XENO services.",
271
- "Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO."
272
- ]
273
- }
274
- }
275
-
276
- def classify_intent(self, message: str) -> Tuple[str, str]:
277
- message_lower = message.lower().strip()
278
- for intent_name, intent_data in self.intent_patterns.items():
279
- for pattern in intent_data['patterns']:
280
- if re.search(pattern, message_lower, re.IGNORECASE):
281
- import random
282
- response = random.choice(intent_data['responses'])
283
- return intent_name, response
284
- return 'query', ''
285
-
286
  intent_classifier = IntentClassifier()
287
 
288
  # === Load and Clean Knowledge Base ===
289
- try:
290
- df_kb = pd.read_json("XENO_Uganda_KnowledgeBase_Advisory.json")
291
- df_kb.dropna(subset=['Content'], inplace=True)
292
-
293
- def prepare_documents(data):
294
- documents, metadatas, ids = [], [], []
295
- for item in data:
296
- documents.append(f"Question: {item['Question']}\nAnswer: {item['Content']}")
297
- metadatas.append({
298
- "question": item["Question"],
299
- "content": item["Content"],
300
- "id": str(item["ID"])
301
- })
302
- ids.append(str(item["ID"]))
303
- return documents, metadatas, ids
304
-
305
- xeno_data_list = df_kb.to_dict('records')
306
- documents, metadatas, ids = prepare_documents(xeno_data_list)
307
- except Exception as e:
308
- print(f"Warning: Could not load JSON knowledge base: {e}")
309
- documents, metadatas, ids = [], [], []
310
 
311
  # === Setup ChromaDB ===
312
- try:
313
- client = chromadb.PersistentClient(path="/tmp/xeno_db")
314
- try:
315
- collection = client.get_collection(name=collection_name)
316
- print(f"Loaded existing ChromaDB collection: {collection_name}")
317
- except:
318
- print(f"Creating new ChromaDB collection: {collection_name}")
319
- collection = client.create_collection(name=collection_name)
320
- if documents:
321
- collection.add(documents=documents, metadatas=metadatas, ids=ids)
322
- except Exception as e:
323
- print(f"Failed to initialize ChromaDB: {e}")
324
- raise
325
 
326
- vector_store = Chroma(client=client, collection_name=collection_name)
327
- retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
328
-
329
- # === Prompt System ===
330
- SYSTEM_PROMPT = """You are a friendly XENO Support Assistant, an AI-powered helpful and professional customer service representative.
331
- Use only the information provided in the knowledge base context to answer user queries.
332
- Do not hallucinate. If context doesn't contain relevant info, say so in a calm polite manner by saying I'm sorry, I can't assist with that.
333
- Only use context that is clearly relevant to the user's question.
334
- For greetings like "hi" or "hello", respond politely without using the context.
335
- remember previous conversations."""
336
-
337
- # === Context Processing ===
338
- def process_context(results, cosine_scores, max_results=2):
339
- with timer.time_step("context_processing"):
340
- sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
341
- formatted_context = ""
342
- source_ids = []
343
- knowledge_pairs = []
344
- for i, idx in enumerate(sorted_indices, 1):
345
- result = results[idx]
346
- score = cosine_scores[idx]
347
- question = result.metadata.get('question', 'N/A')
348
- answer = result.metadata.get('content', 'N/A')
349
- formatted_context += f"Knowledge Entry {i}:\n"
350
- formatted_context += f"Q: {question}\n"
351
- formatted_context += f"A: {answer}\n"
352
- formatted_context += "-" * 40 + "\n"
353
- source_ids.append(str(result.metadata.get('id', 'N/A')))
354
- knowledge_pairs.append((question, answer))
355
- return formatted_context, source_ids, knowledge_pairs
356
-
357
- # === LLM Generation ===
358
- def generate_xeno_response(context, question, chat_history):
359
- with timer.time_step("llm_generation"):
360
- formatted_history = "\n".join(
361
- [f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history]
362
- ) if chat_history else "None"
363
-
364
- prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
365
-
366
- response = genai_client.models.generate_content(
367
- model=llm_model_name,
368
- contents={"text": prompt},
369
- )
370
- return response.text.strip()
371
 
372
- # === Main Interface Logic ===
373
- def get_context_and_answer(message, history, session_id="default"):
374
- # Reset timer for new request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  timer.reset()
376
  error_step = None
377
  notes = []
378
-
379
  try:
380
- # Create session config
381
- config = create_session_config(session_id)
382
-
383
  # Step 1: Intent Classification
384
  intent, direct_response = intent_classifier.classify_intent(message)
385
-
386
  # Step 2: Memory Retrieval
387
- chat_history = retrieve_memory(config)
388
-
389
  answer = ""
390
  source_ids = "N/A"
391
  knowledge_pairs = []
392
 
393
- if intent != 'query':
394
  answer = direct_response
395
  notes.append(f"Simple intent: {intent}")
396
- else:
397
  if len(message.strip()) < 3:
398
  answer = "I'd be happy to help! Could you please provide more details about what you'd like to know?"
399
  notes.append("Message too short")
@@ -402,17 +89,19 @@ def get_context_and_answer(message, history, session_id="default"):
402
  # Step 3: RAG Retrieval
403
  with timer.time_step("rag_retrieval"):
404
  queried_results = retriever.invoke(message)
405
-
406
  # Step 4: Embedding Generation
407
  query_embedding, doc_embeddings = generate_embeddings(
408
  message, queried_results, timer
409
  )
410
-
411
  # Step 5: Similarity Calculation
412
  with timer.time_step("similarity_calculation"):
 
 
413
  cosine_scores = util.cos_sim(
414
- torch.tensor(query_embedding).float(),
415
- torch.tensor(doc_embeddings).float()
416
  )[0].tolist()
417
  max_score = max(cosine_scores) if cosine_scores else 0
418
 
@@ -421,8 +110,10 @@ def get_context_and_answer(message, history, session_id="default"):
421
  notes.append(f"Low similarity score: {max_score:.3f}")
422
  else:
423
  # Step 6: Context Processing
424
- context, source_ids_list, knowledge_pairs = process_context(queried_results, cosine_scores)
425
-
 
 
426
  # Step 7: LLM Generation
427
  answer = generate_xeno_response(context, message, chat_history)
428
  source_ids = ", ".join(source_ids_list)
@@ -436,126 +127,44 @@ def get_context_and_answer(message, history, session_id="default"):
436
  notes.append(f"Error: {str(e)}")
437
 
438
  # Step 8: Memory Update
439
- update_memory(config, message, answer)
440
-
441
  # Step 9: Response Logging
442
  log_response(message, answer, source_ids, knowledge_pairs, session_id)
443
-
444
  # Log timing data
445
  timing_summary = timer.get_timing_summary()
446
  log_timing_data(
447
- message,
448
- session_id,
449
- timing_summary,
450
  error_step=error_step,
451
- notes="; ".join(notes) if notes else None
452
  )
453
-
454
  return answer
455
-
456
  except Exception as e:
457
  error_step = timer.current_step or "main_pipeline"
458
  logging.error(f"Error in main pipeline: {e}")
459
  logging.error(traceback.format_exc())
460
-
461
  timing_summary = timer.get_timing_summary()
462
  log_timing_data(
463
- message,
464
- session_id,
465
- timing_summary,
466
  error_step=error_step,
467
- notes=f"Pipeline error: {str(e)}"
468
  )
469
-
470
- return "I apologize, but I encountered an error processing your request. Please try again."
471
-
472
 
473
- # === Enhanced Gradio UI ===
474
- def respond(message: str, history: List, session_id: str):
475
- """Gradio's main response function"""
476
- if not session_id:
477
- session_id = str(uuid.uuid4())
478
-
479
- bot_response = get_context_and_answer(message, history, session_id)
480
- history.append([message, bot_response])
481
-
482
- return "", history
483
-
484
-
485
- def create_interface():
486
- """Create Gradio interface"""
487
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
488
- gr.Markdown("""
489
- # ASKXENO
490
- **Welcome to XENO AI Support!**
491
-
492
- I can help you with questions about XENO financial services including:
493
- - Account management and setup
494
- - Transaction processes and fees
495
- - Platform features and troubleshooting
496
- - General service information
497
-
498
- *Simply type your question below to get started!*
499
- """)
500
-
501
- # Hidden state for session
502
- session_id_box = gr.Textbox(label="Session ID", value=str(uuid.uuid4()), visible=False)
503
-
504
- chatbot = gr.Chatbot(
505
- label="XENO Assistant",
506
- bubble_full_width=False,
507
- height=450
508
- )
509
-
510
- with gr.Row():
511
- msg = gr.Textbox(
512
- label="Your Message",
513
- placeholder="Type your question here...",
514
- scale=4,
515
- )
516
- send_button = gr.Button("Send", variant="primary", scale=1)
517
-
518
- # ===== FEEDBACK SECTION =====
519
- with gr.Row():
520
- with gr.Accordion("Rate this response / Flag Issue", open=False):
521
- with gr.Row():
522
- thumbs_up = gr.Button("👍 Good Answer")
523
- thumbs_down = gr.Button("👎 Bad / Flag")
524
-
525
- feedback_reason = gr.Textbox(
526
- label="Reason ",
527
- placeholder="E.g., Incorrect fees, hallucination,"
528
- )
529
- feedback_status = gr.Label(value="", label="Status", show_label=False)
530
-
531
- # Feedback Event Listeners
532
- # Logic: If Thumbs Up is clicked, send 'Positive'. If Textbox is empty, reason defaults to "Good".
533
- thumbs_up.click(
534
- fn=lambda h, s, r: submit_feedback("Positive", r if r else "Good", h, s),
535
- inputs=[chatbot, session_id_box, feedback_reason],
536
- outputs=[feedback_status]
537
- )
538
-
539
- # Logic: If Thumbs Down is clicked, send 'Negative' with the content of the textbox.
540
- thumbs_down.click(
541
- fn=lambda r, h, s: submit_feedback("Negative", r, h, s),
542
- inputs=[feedback_reason, chatbot, session_id_box],
543
- outputs=[feedback_status]
544
- )
545
- # =============================
546
 
547
- # Chat Event Listeners
548
- send_button.click(respond, [msg, chatbot, session_id_box], [msg, chatbot])
549
- msg.submit(respond, [msg, chatbot, session_id_box], [msg, chatbot])
550
-
551
- return demo
552
 
 
553
 
554
  if __name__ == "__main__":
555
- iface = create_interface()
556
  iface.launch(
557
- share=False,
558
- server_name=SERVER_NAME,
559
- server_port=SERVER_PORT,
560
- ssr_mode=False
561
- )
 
2
  XENO Bot - AI-powered customer service assistant
3
  Main application file with Gradio interface
4
  """
5
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import logging
7
+ import os
8
  import traceback
9
 
10
+ from src.config import (COLLECTION_NAME, EMBEDDING_MODEL, LLM_MODEL_NAME,
11
+ SERVER_NAME, SERVER_PORT, SIMILARITY_THRESHOLD)
 
 
12
  from src.intent_classifier import IntentClassifier
13
+ from src.interface import create_interface
14
+ from src.knowledge_base import get_knowledge_base_data
 
 
 
 
 
15
  from src.logger import log_response, log_timing_data
16
+ from src.memory import create_session_config, retrieve_memory, update_memory
17
+ from src.response_generator import generate_xeno_response
18
+ # Import custom modules
19
+ from src.utils import PipelineTimer
20
+ from src.vector_store import (generate_embeddings, initialize_vector_store,
21
+ process_context)
22
 
23
  # === Configuration ===
24
  # Ensure API Key is set
 
26
  print("WARNING: GEMINI_API_KEY environment variable not found.")
27
 
28
  # Initialize the client
29
+ embedding_model = EMBEDDING_MODEL
30
+ llm_model_name = LLM_MODEL_NAME
31
+ collection_name = COLLECTION_NAME
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # === Intent Classification System ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  intent_classifier = IntentClassifier()
35
 
36
  # === Load and Clean Knowledge Base ===
37
+ documents, metadatas, ids = get_knowledge_base_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # === Setup ChromaDB ===
40
+ collection, vector_store, retriever = initialize_vector_store()
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # === Core Orchestration Logic ===
44
+ def get_context_and_answer(
45
+ message, history, session_id, intent_classifier, retriever
46
+ ):
47
+ """
48
+ Core orchestration function that handles the RAG pipeline
49
+
50
+ Args:
51
+ message: User's message
52
+ history: Chat history
53
+ session_id: Session identifier
54
+ intent_classifier: IntentClassifier instance
55
+ retriever: Vector store retriever instance
56
+
57
+ Returns:
58
+ Generated answer string
59
+ """
60
+ # Create timer per session
61
+ timer = PipelineTimer()
62
  timer.reset()
63
  error_step = None
64
  notes = []
65
+
66
  try:
67
+ # Create session memory config
68
+ memory_config = create_session_config(session_id)
69
+
70
  # Step 1: Intent Classification
71
  intent, direct_response = intent_classifier.classify_intent(message)
72
+
73
  # Step 2: Memory Retrieval
74
+ chat_history = retrieve_memory(memory_config)
75
+
76
  answer = ""
77
  source_ids = "N/A"
78
  knowledge_pairs = []
79
 
80
+ if intent != "query":
81
  answer = direct_response
82
  notes.append(f"Simple intent: {intent}")
83
+ else:
84
  if len(message.strip()) < 3:
85
  answer = "I'd be happy to help! Could you please provide more details about what you'd like to know?"
86
  notes.append("Message too short")
 
89
  # Step 3: RAG Retrieval
90
  with timer.time_step("rag_retrieval"):
91
  queried_results = retriever.invoke(message)
92
+
93
  # Step 4: Embedding Generation
94
  query_embedding, doc_embeddings = generate_embeddings(
95
  message, queried_results, timer
96
  )
97
+
98
  # Step 5: Similarity Calculation
99
  with timer.time_step("similarity_calculation"):
100
+ import sentence_transformers.util as util
101
+ import torch
102
  cosine_scores = util.cos_sim(
103
+ torch.tensor(query_embedding).float(),
104
+ torch.tensor(doc_embeddings).float(),
105
  )[0].tolist()
106
  max_score = max(cosine_scores) if cosine_scores else 0
107
 
 
110
  notes.append(f"Low similarity score: {max_score:.3f}")
111
  else:
112
  # Step 6: Context Processing
113
+ context, source_ids_list, knowledge_pairs = process_context(
114
+ queried_results, cosine_scores
115
+ )
116
+
117
  # Step 7: LLM Generation
118
  answer = generate_xeno_response(context, message, chat_history)
119
  source_ids = ", ".join(source_ids_list)
 
127
  notes.append(f"Error: {str(e)}")
128
 
129
  # Step 8: Memory Update
130
+ update_memory(memory_config, message, answer)
131
+
132
  # Step 9: Response Logging
133
  log_response(message, answer, source_ids, knowledge_pairs, session_id)
134
+
135
  # Log timing data
136
  timing_summary = timer.get_timing_summary()
137
  log_timing_data(
138
+ message,
139
+ session_id,
140
+ timing_summary,
141
  error_step=error_step,
142
+ notes="; ".join(notes) if notes else None,
143
  )
144
+
145
  return answer
146
+
147
  except Exception as e:
148
  error_step = timer.current_step or "main_pipeline"
149
  logging.error(f"Error in main pipeline: {e}")
150
  logging.error(traceback.format_exc())
151
+
152
  timing_summary = timer.get_timing_summary()
153
  log_timing_data(
154
+ message,
155
+ session_id,
156
+ timing_summary,
157
  error_step=error_step,
158
+ notes=f"Pipeline error: {str(e)}",
159
  )
 
 
 
160
 
161
+ return "I apologize, but I encountered an error processing your request. Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
 
 
 
 
 
163
 
164
+ # === Main Interface Logic ===
165
 
166
  if __name__ == "__main__":
167
+ iface = create_interface(intent_classifier, retriever)
168
  iface.launch(
169
+ share=False, server_name=SERVER_NAME, server_port=SERVER_PORT, ssr_mode=False
170
+ )
 
 
 
docker-compose.yml CHANGED
@@ -1,5 +1,6 @@
1
  services:
2
  xeno-bot:
 
3
  build:
4
  context: .
5
  dockerfile: Dockerfile
 
1
  services:
2
  xeno-bot:
3
+ image: rogerzmukiibi/xeno-bot:test_v1
4
  build:
5
  context: .
6
  dockerfile: Dockerfile
requirements.txt CHANGED
@@ -2,7 +2,7 @@ huggingface_hub==0.25.2
2
  gradio
3
  pydantic==2.10.6
4
  pandas
5
- torch==2.3.1
6
  numpy
7
  sentence-transformers
8
  google-genai
 
2
  gradio
3
  pydantic==2.10.6
4
  pandas
5
+ torch>=2.3.1
6
  numpy
7
  sentence-transformers
8
  google-genai
src/config.py CHANGED
@@ -2,7 +2,9 @@
2
  Configuration module for XENO Bot
3
  Handles environment variables and application settings
4
  """
 
5
  import os
 
6
  from google import genai
7
 
8
  # === API Configuration ===
@@ -11,7 +13,7 @@ if not GEMINI_API_KEY:
11
  raise ValueError("GEMINI_API_KEY environment variable not set.")
12
 
13
  # Initialize the genai client
14
- client = genai.Client(api_key=GEMINI_API_KEY)
15
 
16
  # === Model Configuration ===
17
  EMBEDDING_MODEL = "text-embedding-004"
@@ -30,6 +32,7 @@ GOOGLE_SHEETS_CREDENTIALS_ENV = "GOOGLE_SHEETS_CREDENTIALS"
30
  SPREADSHEET_NAME = "Response_Log"
31
  RESPONSE_SHEET_INDEX = 0 # sheet1
32
  TIMING_SHEET_NAME = "Timing_Log"
 
33
 
34
  # === RAG Configuration ===
35
  RAG_TOP_K = 4
 
2
  Configuration module for XENO Bot
3
  Handles environment variables and application settings
4
  """
5
+
6
  import os
7
+
8
  from google import genai
9
 
10
  # === API Configuration ===
 
13
  raise ValueError("GEMINI_API_KEY environment variable not set.")
14
 
15
  # Initialize the genai client
16
+ genai_client = genai.Client(api_key=GEMINI_API_KEY)
17
 
18
  # === Model Configuration ===
19
  EMBEDDING_MODEL = "text-embedding-004"
 
32
  SPREADSHEET_NAME = "Response_Log"
33
  RESPONSE_SHEET_INDEX = 0 # sheet1
34
  TIMING_SHEET_NAME = "Timing_Log"
35
+ FEEDBACK_SHEET_NAME = "Feedback_Log"
36
 
37
  # === RAG Configuration ===
38
  RAG_TOP_K = 4
src/intent_classifier.py CHANGED
@@ -2,62 +2,63 @@
2
  Intent Classification module for XENO Bot
3
  Handles classification of user intents (greetings, thanks, goodbye, queries)
4
  """
5
- import re
6
  import random
7
- from typing import Tuple, List
 
8
 
9
 
10
  class IntentClassifier:
11
  """Classifies user intents and provides appropriate responses"""
12
-
13
  def __init__(self):
14
  self.intent_patterns = {
15
- 'greeting': {
16
- 'patterns': [
17
- r'\b(hi|hello|hey|good morning|good afternoon|good evening|greetings)\b',
18
- r'^(hi|hello|hey)[\s!.]*$',
19
- r'\b(how are you|how do you do)\b'
20
  ],
21
- 'responses': [
22
  "Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
23
  "Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
24
- "Good day! Welcome to XENO Support. How may I assist you today?"
25
- ]
26
  },
27
- 'thanks': {
28
- 'patterns': [
29
- r'\b(thank you|thanks|thank u|thx|appreciate|grateful)\b',
30
- r'^(thanks|thank you)[\s!.]*$',
31
- r'\b(much appreciated|thanks a lot|thank you so much)\b'
32
  ],
33
- 'responses': [
34
  "You're welcome! Is there anything else I can help you with regarding XENO services?",
35
  "Happy to help! Feel free to ask if you have any other questions about XENO.",
36
- "Glad I could assist you! Let me know if you need help with anything else."
37
- ]
38
  },
39
- 'goodbye': {
40
- 'patterns': [
41
- r'\b(bye|goodbye|see you|farewell|take care|have a good day)\b',
42
- r'^(bye|goodbye)[\s!.]*$',
43
- r'\b(talk to you later|see you later|until next time)\b'
44
  ],
45
- 'responses': [
46
  "Goodbye! Thank you for using XENO services. Have a great day!",
47
  "Take care! Feel free to return anytime you need help with XENO services.",
48
- "Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO."
49
- ]
50
- }
51
  }
52
-
53
  def classify_intent(self, message: str, timer=None) -> Tuple[str, str]:
54
  """
55
  Classify the intent of a user message
56
-
57
  Args:
58
  message: User's message
59
  timer: Optional timer object for tracking
60
-
61
  Returns:
62
  Tuple of (intent_name, response_text)
63
  """
@@ -66,42 +67,42 @@ class IntentClassifier:
66
  return self._classify_intent_impl(message)
67
  else:
68
  return self._classify_intent_impl(message)
69
-
70
  def _classify_intent_impl(self, message: str) -> Tuple[str, str]:
71
  """Internal implementation of intent classification"""
72
  message_lower = message.lower().strip()
73
-
74
  for intent_name, intent_data in self.intent_patterns.items():
75
- for pattern in intent_data['patterns']:
76
  if re.search(pattern, message_lower, re.IGNORECASE):
77
- response = random.choice(intent_data['responses'])
78
  return intent_name, response
79
-
80
- return 'query', ''
81
-
82
  def is_simple_intent(self, intent: str) -> bool:
83
  """
84
  Check if the intent is a simple one that doesn't require RAG
85
-
86
  Args:
87
  intent: Intent name
88
-
89
  Returns:
90
  True if simple intent, False otherwise
91
  """
92
- simple_intents = ['greeting', 'thanks']
93
  return intent in simple_intents
94
-
95
  def add_intent(self, intent_name: str, patterns: List[str], responses: List[str]):
96
  """
97
  Add a new intent to the classifier
98
-
99
  Args:
100
  intent_name: Name of the intent
101
  patterns: List of regex patterns to match
102
  responses: List of possible responses
103
  """
104
  self.intent_patterns[intent_name] = {
105
- 'patterns': patterns,
106
- 'responses': responses
107
  }
 
2
  Intent Classification module for XENO Bot
3
  Handles classification of user intents (greetings, thanks, goodbye, queries)
4
  """
5
+
6
  import random
7
+ import re
8
+ from typing import List, Tuple
9
 
10
 
11
  class IntentClassifier:
12
  """Classifies user intents and provides appropriate responses"""
13
+
14
  def __init__(self):
15
  self.intent_patterns = {
16
+ "greeting": {
17
+ "patterns": [
18
+ r"\b(hi|hello|hey|good morning|good afternoon|good evening|greetings)\b",
19
+ r"^(hi|hello|hey)[\s!.]*$",
20
+ r"\b(how are you|how do you do)\b",
21
  ],
22
+ "responses": [
23
  "Hello! I'm XENO Assistant. How can I help you with XENO financial services today?",
24
  "Hi there! I'm here to assist you with any questions about XENO services. What can I help you with?",
25
+ "Good day! Welcome to XENO Support. How may I assist you today?",
26
+ ],
27
  },
28
+ "thanks": {
29
+ "patterns": [
30
+ r"\b(thank you|thanks|thank u|thx|appreciate|grateful)\b",
31
+ r"^(thanks|thank you)[\s!.]*$",
32
+ r"\b(much appreciated|thanks a lot|thank you so much)\b",
33
  ],
34
+ "responses": [
35
  "You're welcome! Is there anything else I can help you with regarding XENO services?",
36
  "Happy to help! Feel free to ask if you have any other questions about XENO.",
37
+ "Glad I could assist you! Let me know if you need help with anything else.",
38
+ ],
39
  },
40
+ "goodbye": {
41
+ "patterns": [
42
+ r"\b(bye|goodbye|see you|farewell|take care|have a good day)\b",
43
+ r"^(bye|goodbye)[\s!.]*$",
44
+ r"\b(talk to you later|see you later|until next time)\b",
45
  ],
46
+ "responses": [
47
  "Goodbye! Thank you for using XENO services. Have a great day!",
48
  "Take care! Feel free to return anytime you need help with XENO services.",
49
+ "Have a wonderful day! Don't hesitate to reach out if you need assistance with XENO.",
50
+ ],
51
+ },
52
  }
53
+
54
  def classify_intent(self, message: str, timer=None) -> Tuple[str, str]:
55
  """
56
  Classify the intent of a user message
57
+
58
  Args:
59
  message: User's message
60
  timer: Optional timer object for tracking
61
+
62
  Returns:
63
  Tuple of (intent_name, response_text)
64
  """
 
67
  return self._classify_intent_impl(message)
68
  else:
69
  return self._classify_intent_impl(message)
70
+
71
  def _classify_intent_impl(self, message: str) -> Tuple[str, str]:
72
  """Internal implementation of intent classification"""
73
  message_lower = message.lower().strip()
74
+
75
  for intent_name, intent_data in self.intent_patterns.items():
76
+ for pattern in intent_data["patterns"]:
77
  if re.search(pattern, message_lower, re.IGNORECASE):
78
+ response = random.choice(intent_data["responses"])
79
  return intent_name, response
80
+
81
+ return "query", ""
82
+
83
  def is_simple_intent(self, intent: str) -> bool:
84
  """
85
  Check if the intent is a simple one that doesn't require RAG
86
+
87
  Args:
88
  intent: Intent name
89
+
90
  Returns:
91
  True if simple intent, False otherwise
92
  """
93
+ simple_intents = ["greeting", "thanks"]
94
  return intent in simple_intents
95
+
96
  def add_intent(self, intent_name: str, patterns: List[str], responses: List[str]):
97
  """
98
  Add a new intent to the classifier
99
+
100
  Args:
101
  intent_name: Name of the intent
102
  patterns: List of regex patterns to match
103
  responses: List of possible responses
104
  """
105
  self.intent_patterns[intent_name] = {
106
+ "patterns": patterns,
107
+ "responses": responses,
108
  }
src/interface.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import List
3
+
4
+ import gradio as gr
5
+
6
+ from src.logger import log_feedback
7
+
8
+
9
+ def respond(
10
+ message: str, history: List, session_id: str, intent_classifier, retriever
11
+ ):
12
+ """
13
+ Gradio's main response function
14
+
15
+ Args:
16
+ message: User's message
17
+ history: Chat history
18
+ session_id: Session identifier
19
+ intent_classifier: IntentClassifier instance
20
+ retriever: Vector store retriever instance
21
+
22
+ Returns:
23
+ Tuple of (empty string for input box, updated history)
24
+ """
25
+ # Import here to avoid circular imports
26
+ from app import get_context_and_answer
27
+
28
+ if not session_id:
29
+ session_id = str(uuid.uuid4())
30
+
31
+ bot_response = get_context_and_answer(
32
+ message, history, session_id, intent_classifier, retriever
33
+ )
34
+ history.append([message, bot_response])
35
+
36
+ return "", history
37
+
38
+
39
+ def create_interface(intent_classifier, retriever):
40
+ """
41
+ Create Gradio interface
42
+
43
+ Args:
44
+ intent_classifier: IntentClassifier instance
45
+ retriever: Vector store retriever instance
46
+
47
+ Returns:
48
+ Gradio Blocks interface
49
+ """
50
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
51
+ gr.Markdown("""
52
+ # ASKXENO
53
+ **Welcome to XENO AI Support!**
54
+
55
+ I can help you with questions about XENO financial services including:
56
+ - Account management and setup
57
+ - Transaction processes and fees
58
+ - Platform features and troubleshooting
59
+ - General service information
60
+
61
+ *Simply type your question below to get started!*
62
+ """)
63
+
64
+ # Hidden state for session
65
+ session_id_box = gr.Textbox(
66
+ label="Session ID", value=str(uuid.uuid4()), visible=False
67
+ )
68
+
69
+ chatbot = gr.Chatbot(
70
+ label="XENO Assistant", bubble_full_width=False, height=450
71
+ )
72
+
73
+ with gr.Row():
74
+ msg = gr.Textbox(
75
+ label="Your Message",
76
+ placeholder="Type your question here...",
77
+ scale=4,
78
+ )
79
+ send_button = gr.Button("Send", variant="primary", scale=1)
80
+
81
+ # ===== FEEDBACK SECTION =====
82
+ with gr.Row():
83
+ with gr.Accordion("Rate this response / Flag Issue", open=False):
84
+ with gr.Row():
85
+ thumbs_up = gr.Button("👍 Good Answer")
86
+ thumbs_down = gr.Button("👎 Bad / Flag")
87
+
88
+ feedback_reason = gr.Textbox(
89
+ label="Reason ", placeholder="E.g., Incorrect fees, hallucination,"
90
+ )
91
+ feedback_status = gr.Label(value="", label="Status", show_label=False)
92
+
93
+ # Feedback Event Listeners
94
+ # Logic: If Thumbs Up is clicked, send 'Positive'. If Textbox is empty, reason defaults to "Good".
95
+ thumbs_up.click(
96
+ fn=lambda h, s, r: log_feedback("Positive", r if r else "Good", h, s),
97
+ inputs=[chatbot, session_id_box, feedback_reason],
98
+ outputs=[feedback_status],
99
+ )
100
+
101
+ # Logic: If Thumbs Down is clicked, send 'Negative' with the content of the textbox.
102
+ thumbs_down.click(
103
+ fn=lambda r, h, s: log_feedback("Negative", r, h, s),
104
+ inputs=[feedback_reason, chatbot, session_id_box],
105
+ outputs=[feedback_status],
106
+ )
107
+ # =============================
108
+
109
+ # Chat Event Listeners - Pass components to respond function
110
+ send_button.click(
111
+ lambda msg, chat, sid: respond(msg, chat, sid, intent_classifier, retriever),
112
+ [msg, chatbot, session_id_box],
113
+ [msg, chatbot],
114
+ )
115
+ msg.submit(
116
+ lambda msg, chat, sid: respond(msg, chat, sid, intent_classifier, retriever),
117
+ [msg, chatbot, session_id_box],
118
+ [msg, chatbot],
119
+ )
120
+
121
+ return demo
src/knowledge_base.py CHANGED
@@ -2,68 +2,80 @@
2
  Knowledge Base module for XENO Bot
3
  Handles loading and preparing knowledge base data
4
  """
 
 
 
5
  import pandas as pd
6
- from typing import List, Dict, Tuple, Any
7
  from src.config import KNOWLEDGE_BASE_PATH
8
 
9
 
10
  def load_knowledge_base(filepath: str = KNOWLEDGE_BASE_PATH) -> pd.DataFrame:
11
  """
12
  Load knowledge base from JSON file
13
-
14
  Args:
15
  filepath: Path to the knowledge base JSON file
16
-
17
  Returns:
18
  DataFrame with knowledge base data
19
  """
20
- df = pd.read_json(filepath)
21
- df.dropna(subset=['Content'], inplace=True)
 
 
 
 
22
  return df
23
 
24
 
25
- def prepare_documents(data: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict], List[str]]:
 
 
26
  """
27
  Prepare documents for vector store
28
-
29
  Args:
30
  data: List of knowledge base entries
31
-
32
  Returns:
33
  Tuple of (documents, metadatas, ids)
34
  """
35
  documents, metadatas, ids = [], [], []
36
-
37
- for item in data:
38
- # Create document text with question and answer
39
- document_text = f"Question: {item['Question']}\nAnswer: {item['Content']}"
40
- documents.append(document_text)
41
-
42
- # Create metadata
43
- metadata = {
44
- "question": item["Question"],
45
- "content": item["Content"],
46
- "section": item.get("Section", ""),
47
- "source": item.get("Source", ""),
48
- "owner": item.get("Owner", ""),
49
- "tag": item.get("Tag", ""),
50
- "id": item["ID"]
51
- }
52
- metadatas.append(metadata)
53
-
54
- # Add ID
55
- ids.append(item["ID"])
56
-
 
 
 
57
  return documents, metadatas, ids
58
 
59
 
60
  def get_knowledge_base_data() -> Tuple[List[str], List[Dict], List[str]]:
61
  """
62
  Load and prepare knowledge base data
63
-
64
  Returns:
65
  Tuple of (documents, metadatas, ids)
66
  """
67
  df = load_knowledge_base()
68
- data_list = df.to_dict('records')
69
  return prepare_documents(data_list)
 
2
  Knowledge Base module for XENO Bot
3
  Handles loading and preparing knowledge base data
4
  """
5
+
6
+ from typing import Any, Dict, Hashable, List, Tuple
7
+
8
  import pandas as pd
9
+
10
  from src.config import KNOWLEDGE_BASE_PATH
11
 
12
 
13
  def load_knowledge_base(filepath: str = KNOWLEDGE_BASE_PATH) -> pd.DataFrame:
14
  """
15
  Load knowledge base from JSON file
16
+
17
  Args:
18
  filepath: Path to the knowledge base JSON file
19
+
20
  Returns:
21
  DataFrame with knowledge base data
22
  """
23
+ try:
24
+ df = pd.read_json(filepath)
25
+ df.dropna(subset=["Content"], inplace=True)
26
+ except Exception as e:
27
+ print(f"Error loading knowledge base: {e}")
28
+ df = pd.DataFrame()
29
  return df
30
 
31
 
32
+ def prepare_documents(
33
+ data: List[Dict[Hashable, Any]],
34
+ ) -> Tuple[List[str], List[Dict], List[str]]:
35
  """
36
  Prepare documents for vector store
37
+
38
  Args:
39
  data: List of knowledge base entries
40
+
41
  Returns:
42
  Tuple of (documents, metadatas, ids)
43
  """
44
  documents, metadatas, ids = [], [], []
45
+
46
+ try:
47
+ for item in data:
48
+ # Create document text with question and answer
49
+ document_text = f"Question: {item['Question']}\nAnswer: {item['Content']}"
50
+ documents.append(document_text)
51
+
52
+ # Create metadata
53
+ metadata = {
54
+ "question": item["Question"],
55
+ "content": item["Content"],
56
+ "section": item.get("Section", ""),
57
+ "source": item.get("Source", ""),
58
+ "owner": item.get("Owner", ""),
59
+ "tag": item.get("Tag", ""),
60
+ "id": item["ID"],
61
+ }
62
+ metadatas.append(metadata)
63
+
64
+ # Add ID
65
+ ids.append(item["ID"])
66
+ except KeyError as e:
67
+ print(f"Missing expected key in data item: {e}")
68
+
69
  return documents, metadatas, ids
70
 
71
 
72
  def get_knowledge_base_data() -> Tuple[List[str], List[Dict], List[str]]:
73
  """
74
  Load and prepare knowledge base data
75
+
76
  Returns:
77
  Tuple of (documents, metadatas, ids)
78
  """
79
  df = load_knowledge_base()
80
+ data_list = df.to_dict("records")
81
  return prepare_documents(data_list)
src/logger.py CHANGED
@@ -2,81 +2,145 @@
2
  Logging module for XENO Bot
3
  Handles Google Sheets logging for responses and timing data
4
  """
 
5
  import json
6
  import os
 
7
  from datetime import datetime
8
- from typing import List, Tuple, Dict, Optional
 
9
  import gspread
10
  from google.oauth2.service_account import Credentials
11
- from src.config import (
12
- GOOGLE_SHEETS_CREDENTIALS_ENV,
13
- SPREADSHEET_NAME,
14
- RESPONSE_SHEET_INDEX,
15
- TIMING_SHEET_NAME
16
- )
17
 
18
 
19
  def get_google_sheets_credentials() -> Credentials:
20
  """
21
  Get Google Sheets credentials from environment variable
22
-
23
  Returns:
24
  Google Sheets credentials object
25
  """
26
  credentials_json = os.environ.get(GOOGLE_SHEETS_CREDENTIALS_ENV)
27
  if not credentials_json:
28
- raise ValueError(f"{GOOGLE_SHEETS_CREDENTIALS_ENV} environment variable not set.")
29
-
 
 
30
  credentials_dict = json.loads(credentials_json)
31
  scope = [
32
- "https://spreadsheets.google.com/feeds",
33
- "https://www.googleapis.com/auth/drive"
34
  ]
35
  creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
36
-
37
  return creds
38
 
39
 
40
  def initialize_sheets():
41
  """
42
  Initialize Google Sheets client and get sheets
43
-
44
  Returns:
45
  Tuple of (response_sheet, timing_sheet)
46
  """
47
- client_gspread = gspread.authorize(get_google_sheets_credentials())
48
- spreadsheet = client_gspread.open(SPREADSHEET_NAME)
49
-
50
- # Get response sheet
51
- response_sheet = spreadsheet.get_worksheet(RESPONSE_SHEET_INDEX)
52
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Get or create timing sheet
54
  try:
55
  timing_sheet = spreadsheet.worksheet(TIMING_SHEET_NAME)
56
  except:
57
  # Create timing sheet if it doesn't exist
58
- timing_sheet = spreadsheet.add_worksheet(title=TIMING_SHEET_NAME, rows="1000", cols="15")
59
- # Add headers
60
- headers = [
61
- "Timestamp", "Session_ID", "Question", "Total_Time_MS",
62
- "Intent_Classification_MS", "Memory_Retrieval_MS", "RAG_Retrieval_MS",
63
- "Embedding_Generation_MS", "Similarity_Calculation_MS", "Context_Processing_MS",
64
- "LLM_Generation_MS", "Memory_Update_MS", "Logging_MS", "Error_Step", "Notes"
65
- ]
66
- timing_sheet.append_row(headers)
67
-
68
- return response_sheet, timing_sheet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  # Initialize sheets
72
- response_sheet, timing_sheet = initialize_sheets()
73
 
74
 
75
- def log_response(question: str, answer: str, source_ids: str,
76
- knowledge_pairs: List[Tuple[str, str]], session_id: str, timer=None):
 
 
 
 
 
 
77
  """
78
  Log response to Google Sheets
79
-
80
  Args:
81
  question: User's question
82
  answer: Generated answer
@@ -87,28 +151,41 @@ def log_response(question: str, answer: str, source_ids: str,
87
  """
88
  if timer:
89
  with timer.time_step("response_logging"):
90
- _log_response_impl(question, answer, source_ids, knowledge_pairs, session_id)
 
 
91
  else:
92
  _log_response_impl(question, answer, source_ids, knowledge_pairs, session_id)
93
 
94
 
95
- def _log_response_impl(question: str, answer: str, source_ids: str,
96
- knowledge_pairs: List[Tuple[str, str]], session_id: str):
 
 
 
 
 
97
  """Internal implementation of response logging"""
98
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
99
-
100
  # Extract knowledge pairs
101
  knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
102
  knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
103
  knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
104
  knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
105
-
106
  row = [
107
- timestamp, session_id, question, answer, source_ids,
108
- knowledge_question_1, knowledge_answer_1,
109
- knowledge_question_2, knowledge_answer_2
 
 
 
 
 
 
110
  ]
111
-
112
  try:
113
  response_sheet.append_row(row)
114
  print(f"Logged response: {question} | Source IDs: {source_ids}")
@@ -116,14 +193,21 @@ def _log_response_impl(question: str, answer: str, source_ids: str,
116
  print(f"Failed to log to Google Sheet: {e}")
117
  # Fallback to local file
118
  with open("/tmp/response_log.txt", "a") as f:
119
- f.write(f"{timestamp},{question},{answer},{source_ids},{knowledge_question_1},{knowledge_answer_1},{knowledge_question_2},{knowledge_answer_2}\n")
 
 
120
 
121
 
122
- def log_timing_data(question: str, session_id: str, timing_summary: Dict,
123
- error_step: Optional[str] = None, notes: Optional[str] = None):
 
 
 
 
 
124
  """
125
  Log timing data to Google Sheets
126
-
127
  Args:
128
  question: User's question
129
  session_id: Session identifier
@@ -132,29 +216,29 @@ def log_timing_data(question: str, session_id: str, timing_summary: Dict,
132
  notes: Additional notes
133
  """
134
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
135
- step_times = timing_summary['step_times']
136
-
137
  # Truncate long questions
138
  truncated_question = question[:100] + "..." if len(question) > 100 else question
139
-
140
  row = [
141
  timestamp,
142
  session_id,
143
  truncated_question,
144
- timing_summary['total_time_ms'],
145
- step_times.get('intent_classification', 0),
146
- step_times.get('memory_retrieval', 0),
147
- step_times.get('rag_retrieval', 0),
148
- step_times.get('embedding_generation', 0),
149
- step_times.get('similarity_calculation', 0),
150
- step_times.get('context_processing', 0),
151
- step_times.get('llm_generation', 0),
152
- step_times.get('memory_update', 0),
153
- step_times.get('response_logging', 0),
154
  error_step or "",
155
- notes or ""
156
  ]
157
-
158
  try:
159
  timing_sheet.append_row(row)
160
  print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
@@ -163,3 +247,46 @@ def log_timing_data(question: str, session_id: str, timing_summary: Dict,
163
  # Fallback to local file
164
  with open("/tmp/timing_log.txt", "a") as f:
165
  f.write(f"{timestamp},{session_id},{question},{timing_summary}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  Logging module for XENO Bot
3
  Handles Google Sheets logging for responses and timing data
4
  """
5
+
6
  import json
7
  import os
8
+ import threading
9
  from datetime import datetime
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
  import gspread
13
  from google.oauth2.service_account import Credentials
14
+
15
+ from src.config import (FEEDBACK_SHEET_NAME, GOOGLE_SHEETS_CREDENTIALS_ENV,
16
+ RESPONSE_SHEET_INDEX, SPREADSHEET_NAME,
17
+ TIMING_SHEET_NAME)
 
 
18
 
19
 
20
  def get_google_sheets_credentials() -> Credentials:
21
  """
22
  Get Google Sheets credentials from environment variable
23
+
24
  Returns:
25
  Google Sheets credentials object
26
  """
27
  credentials_json = os.environ.get(GOOGLE_SHEETS_CREDENTIALS_ENV)
28
  if not credentials_json:
29
+ raise ValueError(
30
+ f"{GOOGLE_SHEETS_CREDENTIALS_ENV} environment variable not set."
31
+ )
32
+
33
  credentials_dict = json.loads(credentials_json)
34
  scope = [
35
+ "https://spreadsheets.google.com/feeds",
36
+ "https://www.googleapis.com/auth/drive",
37
  ]
38
  creds = Credentials.from_service_account_info(credentials_dict, scopes=scope)
39
+
40
  return creds
41
 
42
 
43
  def initialize_sheets():
44
  """
45
  Initialize Google Sheets client and get sheets
46
+
47
  Returns:
48
  Tuple of (response_sheet, timing_sheet)
49
  """
50
+ try:
51
+ client_gspread = gspread.authorize(get_google_sheets_credentials())
52
+ spreadsheet = client_gspread.open(SPREADSHEET_NAME)
53
+
54
+ # Get response sheet
55
+ response_sheet = spreadsheet.get_worksheet(RESPONSE_SHEET_INDEX)
56
+ except Exception as e:
57
+ print(f"Failed to initialize Google Sheets: {e}")
58
+
59
+ # TODO Create dummy sheets or handle error appropriately
60
+ class DummySheet:
61
+ def append_row(self, *args, **kwargs):
62
+ pass
63
+
64
+ def worksheet(self, *args):
65
+ return self
66
+
67
+ def add_worksheet(self, *args, **kwargs):
68
+ return self
69
+
70
+ spreadsheet = DummySheet()
71
+ response_sheet = DummySheet()
72
+
73
  # Get or create timing sheet
74
  try:
75
  timing_sheet = spreadsheet.worksheet(TIMING_SHEET_NAME)
76
  except:
77
  # Create timing sheet if it doesn't exist
78
+ try:
79
+ timing_sheet = spreadsheet.add_worksheet(
80
+ title=TIMING_SHEET_NAME, rows=1000, cols=15
81
+ )
82
+ # Add headers
83
+ headers = [
84
+ "Timestamp",
85
+ "Session_ID",
86
+ "Question",
87
+ "Total_Time_MS",
88
+ "Intent_Classification_MS",
89
+ "Memory_Retrieval_MS",
90
+ "RAG_Retrieval_MS",
91
+ "Embedding_Generation_MS",
92
+ "Similarity_Calculation_MS",
93
+ "Context_Processing_MS",
94
+ "LLM_Generation_MS",
95
+ "Memory_Update_MS",
96
+ "Logging_MS",
97
+ "Error_Step",
98
+ "Notes",
99
+ ]
100
+ timing_sheet.append_row(headers)
101
+ except Exception as e:
102
+ print(f"Failed to create timing sheet: {e}")
103
+ timing_sheet = DummySheet()
104
+
105
+ # Feedback Sheet
106
+ try:
107
+ feedback_sheet = spreadsheet.worksheet(FEEDBACK_SHEET_NAME)
108
+ except:
109
+ try:
110
+ feedback_sheet = spreadsheet.add_worksheet(
111
+ title=FEEDBACK_SHEET_NAME, rows=1000, cols=6
112
+ )
113
+ headers = [
114
+ "Timestamp",
115
+ "Session_ID",
116
+ "User_Message",
117
+ "Bot_Response",
118
+ "Rating",
119
+ "Flag_Reason",
120
+ ]
121
+ feedback_sheet.append_row(headers)
122
+ except Exception as e:
123
+ print(f"Failed to create feedback sheet: {e}")
124
+ feedback_sheet = DummySheet()
125
+
126
+ return response_sheet, timing_sheet, feedback_sheet
127
 
128
 
129
  # Initialize sheets
130
+ response_sheet, timing_sheet, feedback_sheet = initialize_sheets()
131
 
132
 
133
+ def log_response(
134
+ question: str,
135
+ answer: str,
136
+ source_ids: str,
137
+ knowledge_pairs: List[Tuple[str, str]],
138
+ session_id: str,
139
+ timer=None,
140
+ ):
141
  """
142
  Log response to Google Sheets
143
+
144
  Args:
145
  question: User's question
146
  answer: Generated answer
 
151
  """
152
  if timer:
153
  with timer.time_step("response_logging"):
154
+ _log_response_impl(
155
+ question, answer, source_ids, knowledge_pairs, session_id
156
+ )
157
  else:
158
  _log_response_impl(question, answer, source_ids, knowledge_pairs, session_id)
159
 
160
 
161
+ def _log_response_impl(
162
+ question: str,
163
+ answer: str,
164
+ source_ids: str,
165
+ knowledge_pairs: List[Tuple[str, str]],
166
+ session_id: str,
167
+ ):
168
  """Internal implementation of response logging"""
169
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
170
+
171
  # Extract knowledge pairs
172
  knowledge_question_1 = knowledge_pairs[0][0] if len(knowledge_pairs) > 0 else "N/A"
173
  knowledge_answer_1 = knowledge_pairs[0][1] if len(knowledge_pairs) > 0 else "N/A"
174
  knowledge_question_2 = knowledge_pairs[1][0] if len(knowledge_pairs) > 1 else "N/A"
175
  knowledge_answer_2 = knowledge_pairs[1][1] if len(knowledge_pairs) > 1 else "N/A"
176
+
177
  row = [
178
+ timestamp,
179
+ session_id,
180
+ question,
181
+ answer,
182
+ source_ids,
183
+ knowledge_question_1,
184
+ knowledge_answer_1,
185
+ knowledge_question_2,
186
+ knowledge_answer_2,
187
  ]
188
+
189
  try:
190
  response_sheet.append_row(row)
191
  print(f"Logged response: {question} | Source IDs: {source_ids}")
 
193
  print(f"Failed to log to Google Sheet: {e}")
194
  # Fallback to local file
195
  with open("/tmp/response_log.txt", "a") as f:
196
+ f.write(
197
+ f"{timestamp},{question},{answer},{source_ids},{knowledge_question_1},{knowledge_answer_1},{knowledge_question_2},{knowledge_answer_2}\n"
198
+ )
199
 
200
 
201
+ def log_timing_data(
202
+ question: str,
203
+ session_id: str,
204
+ timing_summary: Dict,
205
+ error_step: Optional[str] = None,
206
+ notes: Optional[str] = None,
207
+ ):
208
  """
209
  Log timing data to Google Sheets
210
+
211
  Args:
212
  question: User's question
213
  session_id: Session identifier
 
216
  notes: Additional notes
217
  """
218
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
219
+ step_times = timing_summary["step_times"]
220
+
221
  # Truncate long questions
222
  truncated_question = question[:100] + "..." if len(question) > 100 else question
223
+
224
  row = [
225
  timestamp,
226
  session_id,
227
  truncated_question,
228
+ timing_summary["total_time_ms"],
229
+ step_times.get("intent_classification", 0),
230
+ step_times.get("memory_retrieval", 0),
231
+ step_times.get("rag_retrieval", 0),
232
+ step_times.get("embedding_generation", 0),
233
+ step_times.get("similarity_calculation", 0),
234
+ step_times.get("context_processing", 0),
235
+ step_times.get("llm_generation", 0),
236
+ step_times.get("memory_update", 0),
237
+ step_times.get("response_logging", 0),
238
  error_step or "",
239
+ notes or "",
240
  ]
241
+
242
  try:
243
  timing_sheet.append_row(row)
244
  print(f"Logged timing data: Total {timing_summary['total_time_ms']}ms")
 
247
  # Fallback to local file
248
  with open("/tmp/timing_log.txt", "a") as f:
249
  f.write(f"{timestamp},{session_id},{question},{timing_summary}\n")
250
+
251
+
252
+ def _log_feedback_background(row):
253
+ """Helper to run network request in background thread"""
254
+ try:
255
+ if feedback_sheet:
256
+ feedback_sheet.append_row(row)
257
+ print("Feedback logged successfully.")
258
+ else:
259
+ print("Feedback sheet not available.")
260
+ except Exception as e:
261
+ print(f"Failed to log feedback: {e}")
262
+
263
+
264
+ def log_feedback(rating, reason, history, session_id):
265
+ """
266
+ Handles user feedback submission.
267
+ rating: 'Positive' or 'Negative'
268
+ reason: User provided text
269
+ history: Gradio chat history list
270
+ """
271
+ if not history or len(history) == 0:
272
+ return "No conversation to rate yet."
273
+
274
+ # Get the last interaction (Gradio history is a list of lists: [[user, bot], ...])
275
+ last_interaction = history[-1]
276
+
277
+ # Safety check for history format
278
+ if isinstance(last_interaction, list) and len(last_interaction) >= 2:
279
+ user_msg = last_interaction[0]
280
+ bot_msg = last_interaction[1]
281
+ else:
282
+ return "Error reading conversation history."
283
+
284
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
285
+
286
+ # Prepare row data
287
+ row = [timestamp, session_id, user_msg, bot_msg, rating, reason]
288
+
289
+ # Run in thread to prevent UI blocking
290
+ threading.Thread(target=_log_feedback_background, args=(row,)).start()
291
+
292
+ return f"Feedback received ({rating}). Thank you!"
src/memory.py CHANGED
@@ -2,11 +2,14 @@
2
  Memory module for XENO Bot
3
  Handles LangGraph memory operations using SQLite
4
  """
5
- import uuid
6
  import sqlite3
 
7
  from datetime import datetime
8
- from typing import List, Dict, Any
 
9
  from langgraph.checkpoint.sqlite import SqliteSaver
 
10
  from src.config import SQLITE_DB_PATH
11
 
12
  # === LangGraph Memory Setup ===
@@ -14,10 +17,12 @@ conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
14
  memory = SqliteSaver(conn=conn)
15
 
16
 
17
- def update_memory(config: Dict[str, Any], user_message: str, assistant_message: str, timer=None):
 
 
18
  """
19
  Update memory with new messages
20
-
21
  Args:
22
  config: Configuration dictionary with thread_id
23
  user_message: User's message
@@ -31,34 +36,34 @@ def update_memory(config: Dict[str, Any], user_message: str, assistant_message:
31
  _update_memory_impl(config, user_message, assistant_message)
32
 
33
 
34
- def _update_memory_impl(config: Dict[str, Any], user_message: str, assistant_message: str):
35
  """Internal implementation of memory update"""
36
  full_checkpoint = memory.get(config) or {}
37
  messages = full_checkpoint.get("channel_values", {}).get("messages", [])
38
-
39
  messages.append({"role": "user", "content": user_message})
40
  messages.append({"role": "assistant", "content": assistant_message})
41
-
42
  checkpoint_to_save = {
43
  "v": 1,
44
  "id": str(uuid.uuid4()),
45
  "ts": datetime.now().isoformat(),
46
  "channel_values": {"messages": messages},
47
  "channel_versions": {},
48
- "versions_seen": {},
49
  }
50
-
51
  memory.put(config, checkpoint_to_save, {}, {})
52
 
53
 
54
  def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
55
  """
56
  Retrieve memory messages for a session
57
-
58
  Args:
59
  config: Configuration dictionary with thread_id
60
  timer: Optional timer object for tracking
61
-
62
  Returns:
63
  List of message dictionaries
64
  """
@@ -69,7 +74,7 @@ def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
69
  return _retrieve_memory_impl(config)
70
 
71
 
72
- def _retrieve_memory_impl(config: Dict[str, Any]) -> List[Dict[str, str]]:
73
  """Internal implementation of memory retrieval"""
74
  full_checkpoint = memory.get(config) or {}
75
  return full_checkpoint.get("channel_values", {}).get("messages", [])
@@ -78,10 +83,10 @@ def _retrieve_memory_impl(config: Dict[str, Any]) -> List[Dict[str, str]]:
78
  def create_session_config(session_id: str = "default") -> Dict[str, Any]:
79
  """
80
  Create a configuration dictionary for a session
81
-
82
  Args:
83
  session_id: Unique session identifier
84
-
85
  Returns:
86
  Configuration dictionary
87
  """
 
2
  Memory module for XENO Bot
3
  Handles LangGraph memory operations using SQLite
4
  """
5
+
6
  import sqlite3
7
+ import uuid
8
  from datetime import datetime
9
+ from typing import Any, Dict, List
10
+
11
  from langgraph.checkpoint.sqlite import SqliteSaver
12
+
13
  from src.config import SQLITE_DB_PATH
14
 
15
  # === LangGraph Memory Setup ===
 
17
  memory = SqliteSaver(conn=conn)
18
 
19
 
20
+ def update_memory(
21
+ config: Dict[str, Any], user_message: str, assistant_message: str, timer=None
22
+ ):
23
  """
24
  Update memory with new messages
25
+
26
  Args:
27
  config: Configuration dictionary with thread_id
28
  user_message: User's message
 
36
  _update_memory_impl(config, user_message, assistant_message)
37
 
38
 
39
+ def _update_memory_impl(config, user_message: str, assistant_message: str):
40
  """Internal implementation of memory update"""
41
  full_checkpoint = memory.get(config) or {}
42
  messages = full_checkpoint.get("channel_values", {}).get("messages", [])
43
+
44
  messages.append({"role": "user", "content": user_message})
45
  messages.append({"role": "assistant", "content": assistant_message})
46
+
47
  checkpoint_to_save = {
48
  "v": 1,
49
  "id": str(uuid.uuid4()),
50
  "ts": datetime.now().isoformat(),
51
  "channel_values": {"messages": messages},
52
  "channel_versions": {},
53
+ "versions_seen": {},
54
  }
55
+
56
  memory.put(config, checkpoint_to_save, {}, {})
57
 
58
 
59
  def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
60
  """
61
  Retrieve memory messages for a session
62
+
63
  Args:
64
  config: Configuration dictionary with thread_id
65
  timer: Optional timer object for tracking
66
+
67
  Returns:
68
  List of message dictionaries
69
  """
 
74
  return _retrieve_memory_impl(config)
75
 
76
 
77
+ def _retrieve_memory_impl(config) -> List[Dict[str, str]]:
78
  """Internal implementation of memory retrieval"""
79
  full_checkpoint = memory.get(config) or {}
80
  return full_checkpoint.get("channel_values", {}).get("messages", [])
 
83
  def create_session_config(session_id: str = "default") -> Dict[str, Any]:
84
  """
85
  Create a configuration dictionary for a session
86
+
87
  Args:
88
  session_id: Unique session identifier
89
+
90
  Returns:
91
  Configuration dictionary
92
  """
src/response_generator.py CHANGED
@@ -2,21 +2,24 @@
2
  Response Generation module for XENO Bot
3
  Handles LLM response generation
4
  """
5
- from google import genai
6
- from typing import List, Dict
7
- from src.config import LLM_MODEL_NAME, SYSTEM_PROMPT, client
8
 
 
9
 
10
- def generate_xeno_response(context: str, question: str, chat_history: List[Dict[str, str]], timer=None) -> str:
 
 
 
 
 
11
  """
12
  Generate a response using the LLM
13
-
14
  Args:
15
  context: Formatted context from knowledge base
16
  question: User's question
17
  chat_history: List of previous messages
18
  timer: Optional timer object for tracking
19
-
20
  Returns:
21
  Generated response text
22
  """
@@ -27,42 +30,47 @@ def generate_xeno_response(context: str, question: str, chat_history: List[Dict[
27
  return _generate_response_impl(context, question, chat_history)
28
 
29
 
30
- def _generate_response_impl(context: str, question: str, chat_history: List[Dict[str, str]]) -> str:
 
 
31
  """Internal implementation of response generation"""
32
  # Format chat history
33
- formatted_history = "\n".join(
34
- [f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history]
35
- ) if chat_history else "None"
36
-
 
 
 
 
37
  # Build prompt
38
  prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
39
-
40
  # Generate response
41
- response = client.generate_content(
42
- model=LLM_MODEL_NAME,
43
- contents={"text": prompt}
44
  )
45
-
46
  return response.text
47
 
48
 
49
  def format_chat_history(messages: List[Dict[str, str]]) -> str:
50
  """
51
  Format chat history for display or logging
52
-
53
  Args:
54
  messages: List of message dictionaries with 'role' and 'content'
55
-
56
  Returns:
57
  Formatted string representation of chat history
58
  """
59
  if not messages:
60
  return "No previous conversation"
61
-
62
  formatted = []
63
  for msg in messages:
64
- role = msg.get('role', 'unknown').capitalize()
65
- content = msg.get('content', '')
66
  formatted.append(f"{role}: {content}")
67
-
68
  return "\n".join(formatted)
 
2
  Response Generation module for XENO Bot
3
  Handles LLM response generation
4
  """
 
 
 
5
 
6
+ from typing import Dict, List
7
 
8
+ from src.config import LLM_MODEL_NAME, SYSTEM_PROMPT, genai_client
9
+
10
+
11
+ def generate_xeno_response(
12
+ context: str, question: str, chat_history: List[Dict[str, str]], timer=None
13
+ ) -> str:
14
  """
15
  Generate a response using the LLM
16
+
17
  Args:
18
  context: Formatted context from knowledge base
19
  question: User's question
20
  chat_history: List of previous messages
21
  timer: Optional timer object for tracking
22
+
23
  Returns:
24
  Generated response text
25
  """
 
30
  return _generate_response_impl(context, question, chat_history)
31
 
32
 
33
+ def _generate_response_impl(
34
+ context: str, question: str, chat_history: List[Dict[str, str]]
35
+ ) -> str:
36
  """Internal implementation of response generation"""
37
  # Format chat history
38
+ formatted_history = (
39
+ "\n".join(
40
+ [f"{msg['role'].capitalize()}: {msg['content']}" for msg in chat_history]
41
+ )
42
+ if chat_history
43
+ else "None"
44
+ )
45
+
46
  # Build prompt
47
  prompt = f"{SYSTEM_PROMPT}\n### HISTORY ###\n{formatted_history}\n### CONTEXT ###\n{context}\n### QUESTION ###\n{question}"
48
+
49
  # Generate response
50
+ response = genai_client.models.generate_content(
51
+ model=LLM_MODEL_NAME, contents=prompt
 
52
  )
53
+
54
  return response.text
55
 
56
 
57
  def format_chat_history(messages: List[Dict[str, str]]) -> str:
58
  """
59
  Format chat history for display or logging
60
+
61
  Args:
62
  messages: List of message dictionaries with 'role' and 'content'
63
+
64
  Returns:
65
  Formatted string representation of chat history
66
  """
67
  if not messages:
68
  return "No previous conversation"
69
+
70
  formatted = []
71
  for msg in messages:
72
+ role = msg.get("role", "unknown").capitalize()
73
+ content = msg.get("content", "")
74
  formatted.append(f"{role}: {content}")
75
+
76
  return "\n".join(formatted)
src/utils.py CHANGED
@@ -2,6 +2,7 @@
2
  Utilities module for XENO Bot
3
  Handles logging and timing functionality
4
  """
 
5
  import logging
6
  import sys
7
  import time
@@ -13,14 +14,18 @@ from typing import Dict
13
  logging.basicConfig(
14
  filename="app.log",
15
  level=logging.INFO,
16
- format="%(asctime)s - %(levelname)s - %(message)s"
17
  )
18
 
 
19
  def log_exception(exc_type, exc_value, exc_traceback):
20
  """Log uncaught exceptions"""
21
  if issubclass(exc_type, KeyboardInterrupt):
22
  return
23
- logging.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
 
 
 
24
 
25
  sys.excepthook = log_exception
26
  logging.info("App started successfully.")
@@ -29,17 +34,17 @@ logging.info("App started successfully.")
29
  # ===== Time Tracking Class =====
30
  class PipelineTimer:
31
  """Timer for tracking pipeline execution steps"""
32
-
33
  def __init__(self):
34
  self.reset()
35
-
36
  def reset(self):
37
  """Reset all timing data for a new request"""
38
  self.start_time = time.time()
39
  self.step_times = {}
40
  self.step_start = None
41
  self.current_step = None
42
-
43
  @contextmanager
44
  def time_step(self, step_name: str):
45
  """Context manager to time a specific step"""
@@ -49,18 +54,20 @@ class PipelineTimer:
49
  yield
50
  finally:
51
  step_end = time.time()
52
- self.step_times[step_name] = round((step_end - step_start) * 1000, 2) # Convert to milliseconds
 
 
53
  self.current_step = None
54
-
55
  def get_total_time(self):
56
  """Get total elapsed time since reset"""
57
  return round((time.time() - self.start_time) * 1000, 2)
58
-
59
  def get_timing_summary(self) -> Dict:
60
  """Get a summary of all timing data"""
61
  total_time = self.get_total_time()
62
  return {
63
- 'total_time_ms': total_time,
64
- 'step_times': self.step_times,
65
- 'timestamp': datetime.now().isoformat()
66
  }
 
2
  Utilities module for XENO Bot
3
  Handles logging and timing functionality
4
  """
5
+
6
  import logging
7
  import sys
8
  import time
 
14
  logging.basicConfig(
15
  filename="app.log",
16
  level=logging.INFO,
17
+ format="%(asctime)s - %(levelname)s - %(message)s",
18
  )
19
 
20
+
21
  def log_exception(exc_type, exc_value, exc_traceback):
22
  """Log uncaught exceptions"""
23
  if issubclass(exc_type, KeyboardInterrupt):
24
  return
25
+ logging.critical(
26
+ "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
27
+ )
28
+
29
 
30
  sys.excepthook = log_exception
31
  logging.info("App started successfully.")
 
34
  # ===== Time Tracking Class =====
35
  class PipelineTimer:
36
  """Timer for tracking pipeline execution steps"""
37
+
38
  def __init__(self):
39
  self.reset()
40
+
41
  def reset(self):
42
  """Reset all timing data for a new request"""
43
  self.start_time = time.time()
44
  self.step_times = {}
45
  self.step_start = None
46
  self.current_step = None
47
+
48
  @contextmanager
49
  def time_step(self, step_name: str):
50
  """Context manager to time a specific step"""
 
54
  yield
55
  finally:
56
  step_end = time.time()
57
+ self.step_times[step_name] = round(
58
+ (step_end - step_start) * 1000, 2
59
+ ) # Convert to milliseconds
60
  self.current_step = None
61
+
62
  def get_total_time(self):
63
  """Get total elapsed time since reset"""
64
  return round((time.time() - self.start_time) * 1000, 2)
65
+
66
  def get_timing_summary(self) -> Dict:
67
  """Get a summary of all timing data"""
68
  total_time = self.get_total_time()
69
  return {
70
+ "total_time_ms": total_time,
71
+ "step_times": self.step_times,
72
+ "timestamp": datetime.now().isoformat(),
73
  }
src/vector_store.py CHANGED
@@ -2,38 +2,34 @@
2
  Vector Store module for XENO Bot
3
  Handles ChromaDB vector store operations
4
  """
 
 
 
5
  import chromadb
6
  import numpy as np
7
  import torch
8
  from langchain_chroma import Chroma
9
  from sentence_transformers import util
10
- from typing import List, Tuple, Any
11
- from google import genai
12
- from src.config import (
13
- client,
14
- COLLECTION_NAME,
15
- CHROMA_DB_PATH,
16
- RAG_TOP_K,
17
- RAG_MAX_RESULTS,
18
- EMBEDDING_MODEL
19
- )
20
  from src.knowledge_base import get_knowledge_base_data
21
 
22
 
23
  def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
24
  """
25
  Initialize ChromaDB vector store
26
-
27
  Returns:
28
  Tuple of (collection, vector_store, retriever)
29
  """
30
  # Get knowledge base data
31
  documents, metadatas, ids = get_knowledge_base_data()
32
-
33
  # Initialize ChromaDB client
34
  try:
35
  client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
36
-
37
  # Try to get existing collection
38
  try:
39
  collection = client.get_collection(name=COLLECTION_NAME)
@@ -43,30 +39,31 @@ def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
43
  print(f"Creating new ChromaDB collection: {COLLECTION_NAME}")
44
  collection = client.create_collection(name=COLLECTION_NAME)
45
  collection.add(documents=documents, metadatas=metadatas, ids=ids)
46
-
47
  # Create vector store and retriever
48
  vector_store = Chroma(client=client, collection_name=COLLECTION_NAME)
49
  retriever = vector_store.as_retriever(
50
- search_type="similarity",
51
- search_kwargs={"k": RAG_TOP_K}
52
  )
53
-
54
  return collection, vector_store, retriever
55
-
56
  except Exception as e:
57
  print(f"Failed to initialize ChromaDB: {e}")
58
  raise
59
 
60
 
61
- def generate_embeddings(query: str, documents: List[Any], timer=None) -> Tuple[List[float], List[List[float]]]:
 
 
62
  """
63
  Generate embeddings for query and documents
64
-
65
  Args:
66
  query: User query
67
  documents: List of retrieved documents
68
  timer: Optional timer object for tracking
69
-
70
  Returns:
71
  Tuple of (query_embedding, doc_embeddings)
72
  """
@@ -77,38 +74,40 @@ def generate_embeddings(query: str, documents: List[Any], timer=None) -> Tuple[L
77
  return _generate_embeddings_impl(query, documents)
78
 
79
 
80
- def _generate_embeddings_impl(query: str, documents: List[Any]) -> Tuple[List[float], List[List[float]]]:
 
 
81
  """Internal implementation of embedding generation"""
82
  # 1. Update query embedding access
83
- query_result = client.models.embed_content(
84
- model=EMBEDDING_MODEL,
85
- contents=query
86
  )
87
  # The SDK returns an EmbedContentResponse object with an 'embeddings' attribute
88
- query_embedding = query_result.embeddings[0].values
89
-
90
  # 2. Update document embeddings access
91
  doc_contents = [doc.page_content for doc in documents]
92
- doc_results = client.models.embed_content(
93
- model=EMBEDDING_MODEL,
94
- contents=doc_contents
95
  )
96
-
97
  # Map the list of embedding objects to a list of vector values
98
  doc_embeddings = [e.values for e in doc_results.embeddings]
99
-
100
  return query_embedding, doc_embeddings
101
 
102
 
103
- def calculate_similarity(query_embedding: List[float], doc_embeddings: List[List[float]], timer=None) -> List[float]:
 
 
104
  """
105
  Calculate cosine similarity between query and documents
106
-
107
  Args:
108
  query_embedding: Query embedding vector
109
  doc_embeddings: List of document embedding vectors
110
  timer: Optional timer object for tracking
111
-
112
  Returns:
113
  List of cosine similarity scores
114
  """
@@ -119,27 +118,32 @@ def calculate_similarity(query_embedding: List[float], doc_embeddings: List[List
119
  return _calculate_similarity_impl(query_embedding, doc_embeddings)
120
 
121
 
122
- def _calculate_similarity_impl(query_embedding: List[float], doc_embeddings: List[List[float]]) -> List[float]:
 
 
123
  """Internal implementation of similarity calculation"""
124
  cosine_scores = util.cos_sim(
125
- torch.tensor(query_embedding).float(),
126
- torch.tensor(doc_embeddings).float()
127
  )[0].tolist()
128
-
129
  return cosine_scores
130
 
131
 
132
- def process_context(results: List[Any], cosine_scores: List[float],
133
- max_results: int = RAG_MAX_RESULTS, timer=None) -> Tuple[str, List[str], List[Tuple[str, str]]]:
 
 
 
 
134
  """
135
  Process retrieved context and format for LLM
136
-
137
  Args:
138
  results: List of retrieved documents
139
  cosine_scores: List of similarity scores
140
  max_results: Maximum number of results to include
141
  timer: Optional timer object for tracking
142
-
143
  Returns:
144
  Tuple of (formatted_context, source_ids, knowledge_pairs)
145
  """
@@ -150,28 +154,29 @@ def process_context(results: List[Any], cosine_scores: List[float],
150
  return _process_context_impl(results, cosine_scores, max_results)
151
 
152
 
153
- def _process_context_impl(results: List[Any], cosine_scores: List[float],
154
- max_results: int) -> Tuple[str, List[str], List[Tuple[str, str]]]:
 
155
  """Internal implementation of context processing"""
156
  sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
157
-
158
  formatted_context = ""
159
  source_ids = []
160
  knowledge_pairs = []
161
-
162
  for i, idx in enumerate(sorted_indices, 1):
163
  result = results[idx]
164
- score = cosine_scores[idx]
165
-
166
- question = result.metadata.get('question', 'N/A')
167
- answer = result.metadata.get('content', 'N/A')
168
-
169
  formatted_context += f"Knowledge Entry {i}:\n"
170
  formatted_context += f"Q: {question}\n"
171
  formatted_context += f"A: {answer}\n"
172
  formatted_context += "-" * 40 + "\n"
173
-
174
- source_ids.append(result.metadata.get('id', 'N/A'))
175
  knowledge_pairs.append((question, answer))
176
-
177
  return formatted_context, source_ids, knowledge_pairs
 
2
  Vector Store module for XENO Bot
3
  Handles ChromaDB vector store operations
4
  """
5
+
6
+ from typing import Any, List, Tuple
7
+
8
  import chromadb
9
  import numpy as np
10
  import torch
11
  from langchain_chroma import Chroma
12
  from sentence_transformers import util
13
+
14
+ from src.config import (CHROMA_DB_PATH, COLLECTION_NAME, EMBEDDING_MODEL,
15
+ RAG_MAX_RESULTS, RAG_TOP_K, genai_client)
 
 
 
 
 
 
 
16
  from src.knowledge_base import get_knowledge_base_data
17
 
18
 
19
  def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
20
  """
21
  Initialize ChromaDB vector store
22
+
23
  Returns:
24
  Tuple of (collection, vector_store, retriever)
25
  """
26
  # Get knowledge base data
27
  documents, metadatas, ids = get_knowledge_base_data()
28
+
29
  # Initialize ChromaDB client
30
  try:
31
  client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
32
+
33
  # Try to get existing collection
34
  try:
35
  collection = client.get_collection(name=COLLECTION_NAME)
 
39
  print(f"Creating new ChromaDB collection: {COLLECTION_NAME}")
40
  collection = client.create_collection(name=COLLECTION_NAME)
41
  collection.add(documents=documents, metadatas=metadatas, ids=ids)
42
+
43
  # Create vector store and retriever
44
  vector_store = Chroma(client=client, collection_name=COLLECTION_NAME)
45
  retriever = vector_store.as_retriever(
46
+ search_type="similarity", search_kwargs={"k": RAG_TOP_K}
 
47
  )
48
+
49
  return collection, vector_store, retriever
50
+
51
  except Exception as e:
52
  print(f"Failed to initialize ChromaDB: {e}")
53
  raise
54
 
55
 
56
+ def generate_embeddings(
57
+ query: str, documents: List[Any], timer=None
58
+ ) -> Tuple[List[float], List[List[float]]]:
59
  """
60
  Generate embeddings for query and documents
61
+
62
  Args:
63
  query: User query
64
  documents: List of retrieved documents
65
  timer: Optional timer object for tracking
66
+
67
  Returns:
68
  Tuple of (query_embedding, doc_embeddings)
69
  """
 
74
  return _generate_embeddings_impl(query, documents)
75
 
76
 
77
+ def _generate_embeddings_impl(
78
+ query: str, documents: List[Any]
79
+ ) -> Tuple[List[float], List[List[float]]]:
80
  """Internal implementation of embedding generation"""
81
  # 1. Update query embedding access
82
+ query_result = genai_client.models.embed_content(
83
+ model=EMBEDDING_MODEL, contents=query
 
84
  )
85
  # The SDK returns an EmbedContentResponse object with an 'embeddings' attribute
86
+ query_embedding = query_result.embeddings[0].values
87
+
88
  # 2. Update document embeddings access
89
  doc_contents = [doc.page_content for doc in documents]
90
+ doc_results = genai_client.models.embed_content(
91
+ model=EMBEDDING_MODEL, contents=doc_contents
 
92
  )
93
+
94
  # Map the list of embedding objects to a list of vector values
95
  doc_embeddings = [e.values for e in doc_results.embeddings]
96
+
97
  return query_embedding, doc_embeddings
98
 
99
 
100
+ def calculate_similarity(
101
+ query_embedding: List[float], doc_embeddings: List[List[float]], timer=None
102
+ ) -> List[float]:
103
  """
104
  Calculate cosine similarity between query and documents
105
+
106
  Args:
107
  query_embedding: Query embedding vector
108
  doc_embeddings: List of document embedding vectors
109
  timer: Optional timer object for tracking
110
+
111
  Returns:
112
  List of cosine similarity scores
113
  """
 
118
  return _calculate_similarity_impl(query_embedding, doc_embeddings)
119
 
120
 
121
+ def _calculate_similarity_impl(
122
+ query_embedding: List[float], doc_embeddings: List[List[float]]
123
+ ) -> List[float]:
124
  """Internal implementation of similarity calculation"""
125
  cosine_scores = util.cos_sim(
126
+ torch.tensor(query_embedding).float(), torch.tensor(doc_embeddings).float()
 
127
  )[0].tolist()
128
+
129
  return cosine_scores
130
 
131
 
132
+ def process_context(
133
+ results: List[Any],
134
+ cosine_scores: List[float],
135
+ max_results: int = RAG_MAX_RESULTS,
136
+ timer=None,
137
+ ) -> Tuple[str, List[str], List[Tuple[str, str]]]:
138
  """
139
  Process retrieved context and format for LLM
140
+
141
  Args:
142
  results: List of retrieved documents
143
  cosine_scores: List of similarity scores
144
  max_results: Maximum number of results to include
145
  timer: Optional timer object for tracking
146
+
147
  Returns:
148
  Tuple of (formatted_context, source_ids, knowledge_pairs)
149
  """
 
154
  return _process_context_impl(results, cosine_scores, max_results)
155
 
156
 
157
+ def _process_context_impl(
158
+ results: List[Any], cosine_scores: List[float], max_results: int
159
+ ) -> Tuple[str, List[str], List[Tuple[str, str]]]:
160
  """Internal implementation of context processing"""
161
  sorted_indices = np.argsort(cosine_scores)[::-1][:max_results]
162
+
163
  formatted_context = ""
164
  source_ids = []
165
  knowledge_pairs = []
166
+
167
  for i, idx in enumerate(sorted_indices, 1):
168
  result = results[idx]
169
+ cosine_scores[idx]
170
+
171
+ question = result.metadata.get("question", "N/A")
172
+ answer = result.metadata.get("content", "N/A")
173
+
174
  formatted_context += f"Knowledge Entry {i}:\n"
175
  formatted_context += f"Q: {question}\n"
176
  formatted_context += f"A: {answer}\n"
177
  formatted_context += "-" * 40 + "\n"
178
+
179
+ source_ids.append(result.metadata.get("id", "N/A"))
180
  knowledge_pairs.append((question, answer))
181
+
182
  return formatted_context, source_ids, knowledge_pairs
tests/conftest.py CHANGED
@@ -2,16 +2,18 @@
2
  Pytest configuration file
3
  Sets up test environment and fixtures
4
  """
 
5
  import os
6
  import sys
 
 
7
  import pytest
8
- from unittest.mock import Mock, MagicMock, patch, PropertyMock
9
 
10
  # Add src to path
11
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
12
 
13
  # Set mock environment variables before importing any modules
14
- os.environ.setdefault('GEMINI_API_KEY', 'test-api-key-12345')
15
 
16
  # Mock Google Sheets credentials
17
  mock_credentials = {
@@ -24,20 +26,23 @@ mock_credentials = {
24
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
25
  "token_uri": "https://oauth2.googleapis.com/token",
26
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
27
- "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test"
28
  }
29
  import json
30
- os.environ.setdefault('GOOGLE_SHEETS_CREDENTIALS', json.dumps(mock_credentials))
 
31
 
32
  # Mock google.oauth2 and gspread modules before src.logger imports them
33
  mock_credentials_class = MagicMock()
34
  mock_creds_instance = MagicMock()
35
- mock_credentials_class.from_service_account_info = Mock(return_value=mock_creds_instance)
 
 
36
 
37
  mock_oauth2 = MagicMock()
38
  mock_oauth2.service_account.Credentials = mock_credentials_class
39
- sys.modules['google.oauth2'] = mock_oauth2
40
- sys.modules['google.oauth2.service_account'] = mock_oauth2.service_account
41
 
42
  mock_gspread = MagicMock()
43
  mock_spreadsheet = MagicMock()
@@ -49,14 +54,15 @@ mock_spreadsheet.add_worksheet = Mock(return_value=mock_worksheet)
49
  mock_client = MagicMock()
50
  mock_client.open = Mock(return_value=mock_spreadsheet)
51
  mock_gspread.authorize = Mock(return_value=mock_client)
52
- sys.modules['gspread'] = mock_gspread
53
 
54
 
55
  @pytest.fixture(autouse=True)
56
  def mock_google_sheets():
57
  """Mock Google Sheets to avoid actual connections during testing"""
58
- with patch('src.logger.response_sheet') as mock_response, \
59
- patch('src.logger.timing_sheet') as mock_timing:
 
60
  mock_response.append_row = Mock()
61
  mock_timing.append_row = Mock()
62
  yield mock_response, mock_timing
@@ -65,20 +71,16 @@ def mock_google_sheets():
65
  @pytest.fixture
66
  def mock_genai():
67
  """Mock Google Generative AI"""
68
- with patch('google.generativeai.configure') as mock_config, \
69
- patch('google.generativeai.GenerativeModel') as mock_model, \
70
- patch('google.generativeai.embed_content') as mock_embed:
71
- yield {
72
- 'configure': mock_config,
73
- 'model': mock_model,
74
- 'embed': mock_embed
75
- }
76
 
77
 
78
  @pytest.fixture
79
  def mock_chromadb():
80
  """Mock ChromaDB client"""
81
- with patch('chromadb.PersistentClient') as mock_client:
82
  mock_collection = Mock()
83
  mock_client.return_value.get_collection.return_value = mock_collection
84
  yield mock_client
@@ -87,7 +89,7 @@ def mock_chromadb():
87
  @pytest.fixture
88
  def mock_sqlite():
89
  """Mock SQLite connections for memory"""
90
- with patch('sqlite3.connect') as mock_connect:
91
  mock_conn = Mock()
92
  mock_connect.return_value = mock_conn
93
  yield mock_conn
@@ -97,21 +99,42 @@ def mock_sqlite():
97
  def sample_documents():
98
  """Provide sample documents for testing"""
99
  doc1 = Mock()
100
- doc1.page_content = "Question: How do I create an account?\nAnswer: Visit our website."
 
 
101
  doc1.metadata = {
102
- 'id': 'KB001',
103
- 'question': 'How do I create an account?',
104
- 'content': 'Visit our website.',
105
- 'section': 'Account Management'
106
  }
107
-
108
  doc2 = Mock()
109
  doc2.page_content = "Question: What are the fees?\nAnswer: 1% per transaction."
110
  doc2.metadata = {
111
- 'id': 'KB002',
112
- 'question': 'What are the fees?',
113
- 'content': '1% per transaction.',
114
- 'section': 'Fees'
115
  }
116
-
117
  return [doc1, doc2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  Pytest configuration file
3
  Sets up test environment and fixtures
4
  """
5
+
6
  import os
7
  import sys
8
+ from unittest.mock import MagicMock, Mock, patch
9
+
10
  import pytest
 
11
 
12
  # Add src to path
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
 
15
  # Set mock environment variables before importing any modules
16
+ os.environ.setdefault("GEMINI_API_KEY", "test-api-key-12345")
17
 
18
  # Mock Google Sheets credentials
19
  mock_credentials = {
 
26
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
27
  "token_uri": "https://oauth2.googleapis.com/token",
28
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
29
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test",
30
  }
31
  import json
32
+
33
+ os.environ.setdefault("GOOGLE_SHEETS_CREDENTIALS", json.dumps(mock_credentials))
34
 
35
  # Mock google.oauth2 and gspread modules before src.logger imports them
36
  mock_credentials_class = MagicMock()
37
  mock_creds_instance = MagicMock()
38
+ mock_credentials_class.from_service_account_info = Mock(
39
+ return_value=mock_creds_instance
40
+ )
41
 
42
  mock_oauth2 = MagicMock()
43
  mock_oauth2.service_account.Credentials = mock_credentials_class
44
+ sys.modules["google.oauth2"] = mock_oauth2
45
+ sys.modules["google.oauth2.service_account"] = mock_oauth2.service_account
46
 
47
  mock_gspread = MagicMock()
48
  mock_spreadsheet = MagicMock()
 
54
  mock_client = MagicMock()
55
  mock_client.open = Mock(return_value=mock_spreadsheet)
56
  mock_gspread.authorize = Mock(return_value=mock_client)
57
+ sys.modules["gspread"] = mock_gspread
58
 
59
 
60
  @pytest.fixture(autouse=True)
61
  def mock_google_sheets():
62
  """Mock Google Sheets to avoid actual connections during testing"""
63
+ with patch("src.logger.response_sheet") as mock_response, patch(
64
+ "src.logger.timing_sheet"
65
+ ) as mock_timing:
66
  mock_response.append_row = Mock()
67
  mock_timing.append_row = Mock()
68
  yield mock_response, mock_timing
 
71
  @pytest.fixture
72
  def mock_genai():
73
  """Mock Google Generative AI"""
74
+ with patch("google.generativeai.configure") as mock_config, patch(
75
+ "google.generativeai.GenerativeModel"
76
+ ) as mock_model, patch("google.generativeai.embed_content") as mock_embed:
77
+ yield {"configure": mock_config, "model": mock_model, "embed": mock_embed}
 
 
 
 
78
 
79
 
80
  @pytest.fixture
81
  def mock_chromadb():
82
  """Mock ChromaDB client"""
83
+ with patch("chromadb.PersistentClient") as mock_client:
84
  mock_collection = Mock()
85
  mock_client.return_value.get_collection.return_value = mock_collection
86
  yield mock_client
 
89
  @pytest.fixture
90
  def mock_sqlite():
91
  """Mock SQLite connections for memory"""
92
+ with patch("sqlite3.connect") as mock_connect:
93
  mock_conn = Mock()
94
  mock_connect.return_value = mock_conn
95
  yield mock_conn
 
99
  def sample_documents():
100
  """Provide sample documents for testing"""
101
  doc1 = Mock()
102
+ doc1.page_content = (
103
+ "Question: How do I create an account?\nAnswer: Visit our website."
104
+ )
105
  doc1.metadata = {
106
+ "id": "KB001",
107
+ "question": "How do I create an account?",
108
+ "content": "Visit our website.",
109
+ "section": "Account Management",
110
  }
111
+
112
  doc2 = Mock()
113
  doc2.page_content = "Question: What are the fees?\nAnswer: 1% per transaction."
114
  doc2.metadata = {
115
+ "id": "KB002",
116
+ "question": "What are the fees?",
117
+ "content": "1% per transaction.",
118
+ "section": "Fees",
119
  }
120
+
121
  return [doc1, doc2]
122
+
123
+
124
+ @pytest.fixture
125
+ def mock_genai_client():
126
+ """Mock Google Generative AI client with new SDK structure"""
127
+ with patch("src.config.genai_client") as mock_client:
128
+ # Mock generate_content for LLM
129
+ mock_generate_response = Mock()
130
+ mock_generate_response.text = "Test response from LLM"
131
+ mock_client.models.generate_content.return_value = mock_generate_response
132
+
133
+ # Mock embed_content for embeddings
134
+ mock_embedding = Mock()
135
+ mock_embedding.values = [0.1, 0.2, 0.3]
136
+ mock_embed_response = Mock()
137
+ mock_embed_response.embeddings = [mock_embedding]
138
+ mock_client.models.embed_content.return_value = mock_embed_response
139
+
140
+ yield mock_client
tests/test_app.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for app module
3
+ Tests main orchestration logic
4
+ """
5
+
6
+ import unittest
7
+ from unittest.mock import MagicMock, Mock, patch
8
+
9
+ from app import get_context_and_answer
10
+
11
+
12
+ class TestApp(unittest.TestCase):
13
+ """Test cases for app module"""
14
+
15
+ def setUp(self):
16
+ """Set up test fixtures"""
17
+ self.message = "How do I create an account?"
18
+ self.history = [["Previous question", "Previous answer"]]
19
+ self.session_id = "test-session-123"
20
+ self.mock_intent_classifier = Mock()
21
+ self.mock_retriever = Mock()
22
+
23
+ @patch("app.log_timing_data")
24
+ @patch("app.log_response")
25
+ @patch("app.update_memory")
26
+ @patch("app.retrieve_memory")
27
+ @patch("app.create_session_config")
28
+ def test_get_context_and_answer_simple_intent(
29
+ self,
30
+ mock_session_config,
31
+ mock_retrieve_memory,
32
+ mock_update_memory,
33
+ mock_log_response,
34
+ mock_log_timing,
35
+ ):
36
+ """Test get_context_and_answer with simple intent (greeting)"""
37
+ # Setup mocks
38
+ mock_session_config.return_value = {"session_id": self.session_id}
39
+ mock_retrieve_memory.return_value = []
40
+ self.mock_intent_classifier.classify_intent.return_value = (
41
+ "greeting",
42
+ "Hello! How can I help you?",
43
+ )
44
+
45
+ # Call function
46
+ answer = get_context_and_answer(
47
+ "Hello",
48
+ self.history,
49
+ self.session_id,
50
+ self.mock_intent_classifier,
51
+ self.mock_retriever,
52
+ )
53
+
54
+ # Verify intent was classified
55
+ self.mock_intent_classifier.classify_intent.assert_called_once_with("Hello")
56
+
57
+ # Should not use retriever for simple intent
58
+ self.mock_retriever.invoke.assert_not_called()
59
+
60
+ # Verify response
61
+ self.assertEqual(answer, "Hello! How can I help you?")
62
+
63
+ # Verify memory was updated
64
+ mock_update_memory.assert_called_once()
65
+
66
+ # Verify logging
67
+ mock_log_response.assert_called_once()
68
+ mock_log_timing.assert_called_once()
69
+
70
+ @patch("app.generate_xeno_response")
71
+ @patch("app.process_context")
72
+ @patch("app.generate_embeddings")
73
+ @patch("app.log_timing_data")
74
+ @patch("app.log_response")
75
+ @patch("app.update_memory")
76
+ @patch("app.retrieve_memory")
77
+ @patch("app.create_session_config")
78
+ def test_get_context_and_answer_query_intent(
79
+ self,
80
+ mock_session_config,
81
+ mock_retrieve_memory,
82
+ mock_update_memory,
83
+ mock_log_response,
84
+ mock_log_timing,
85
+ mock_generate_embeddings,
86
+ mock_process_context,
87
+ mock_generate_response,
88
+ ):
89
+ """Test get_context_and_answer with query intent"""
90
+ # Setup mocks
91
+ mock_session_config.return_value = {"session_id": self.session_id}
92
+ mock_retrieve_memory.return_value = []
93
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
94
+
95
+ # Mock retriever
96
+ mock_doc = Mock()
97
+ mock_doc.page_content = "Test content"
98
+ mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
99
+ self.mock_retriever.invoke.return_value = [mock_doc]
100
+
101
+ # Mock embeddings
102
+ mock_generate_embeddings.return_value = (
103
+ [0.1, 0.2, 0.3], # query embedding
104
+ [[0.2, 0.3, 0.4]], # doc embeddings
105
+ )
106
+
107
+ # Mock context processing
108
+ mock_process_context.return_value = (
109
+ "Formatted context",
110
+ ["KB001"],
111
+ [("Q", "A")],
112
+ )
113
+
114
+ # Mock LLM response
115
+ mock_generate_response.return_value = "Generated answer"
116
+
117
+ # Call function
118
+ answer = get_context_and_answer(
119
+ self.message,
120
+ self.history,
121
+ self.session_id,
122
+ self.mock_intent_classifier,
123
+ self.mock_retriever,
124
+ )
125
+
126
+ # Verify RAG pipeline was executed
127
+ self.mock_retriever.invoke.assert_called_once_with(self.message)
128
+ mock_generate_embeddings.assert_called_once()
129
+ mock_process_context.assert_called_once()
130
+ mock_generate_response.assert_called_once()
131
+
132
+ # Verify response
133
+ self.assertEqual(answer, "Generated answer")
134
+
135
+ # Verify logging
136
+ mock_log_response.assert_called_once()
137
+ mock_log_timing.assert_called_once()
138
+
139
+ @patch("app.log_timing_data")
140
+ @patch("app.log_response")
141
+ @patch("app.update_memory")
142
+ @patch("app.retrieve_memory")
143
+ @patch("app.create_session_config")
144
+ def test_get_context_and_answer_short_message(
145
+ self,
146
+ mock_session_config,
147
+ mock_retrieve_memory,
148
+ mock_update_memory,
149
+ mock_log_response,
150
+ mock_log_timing,
151
+ ):
152
+ """Test get_context_and_answer with very short message"""
153
+ # Setup mocks
154
+ mock_session_config.return_value = {"session_id": self.session_id}
155
+ mock_retrieve_memory.return_value = []
156
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
157
+
158
+ # Call function with short message
159
+ answer = get_context_and_answer(
160
+ "Hi",
161
+ self.history,
162
+ self.session_id,
163
+ self.mock_intent_classifier,
164
+ self.mock_retriever,
165
+ )
166
+
167
+ # Should return a request for more details
168
+ self.assertIn("more details", answer)
169
+
170
+ # Should not invoke retriever
171
+ self.mock_retriever.invoke.assert_not_called()
172
+
173
+ @patch("app.generate_embeddings")
174
+ @patch("app.log_timing_data")
175
+ @patch("app.log_response")
176
+ @patch("app.update_memory")
177
+ @patch("app.retrieve_memory")
178
+ @patch("app.create_session_config")
179
+ def test_get_context_and_answer_low_similarity(
180
+ self,
181
+ mock_session_config,
182
+ mock_retrieve_memory,
183
+ mock_update_memory,
184
+ mock_log_response,
185
+ mock_log_timing,
186
+ mock_generate_embeddings,
187
+ ):
188
+ """Test get_context_and_answer with low similarity score"""
189
+ # Setup mocks
190
+ mock_session_config.return_value = {"session_id": self.session_id}
191
+ mock_retrieve_memory.return_value = []
192
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
193
+
194
+ # Mock retriever
195
+ mock_doc = Mock()
196
+ mock_doc.page_content = "Test content"
197
+ self.mock_retriever.invoke.return_value = [mock_doc]
198
+
199
+ # Mock embeddings with low similarity
200
+ mock_generate_embeddings.return_value = (
201
+ [0.1, 0.2, 0.3],
202
+ [[1.0, 0.0, 0.0]], # Will result in low cosine score
203
+ )
204
+
205
+ # Call function
206
+ answer = get_context_and_answer(
207
+ "Some random question",
208
+ self.history,
209
+ self.session_id,
210
+ self.mock_intent_classifier,
211
+ self.mock_retriever,
212
+ )
213
+
214
+ # Should return "couldn't find" message
215
+ self.assertIn("couldn't find", answer)
216
+
217
+ @patch("app.log_timing_data")
218
+ @patch("app.log_response")
219
+ @patch("app.update_memory")
220
+ @patch("app.retrieve_memory")
221
+ @patch("app.create_session_config")
222
+ def test_get_context_and_answer_rag_error(
223
+ self,
224
+ mock_session_config,
225
+ mock_retrieve_memory,
226
+ mock_update_memory,
227
+ mock_log_response,
228
+ mock_log_timing,
229
+ ):
230
+ """Test get_context_and_answer handles RAG errors gracefully"""
231
+ # Setup mocks
232
+ mock_session_config.return_value = {"session_id": self.session_id}
233
+ mock_retrieve_memory.return_value = []
234
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
235
+
236
+ # Mock retriever to raise exception
237
+ self.mock_retriever.invoke.side_effect = Exception("Database error")
238
+
239
+ # Call function
240
+ answer = get_context_and_answer(
241
+ self.message,
242
+ self.history,
243
+ self.session_id,
244
+ self.mock_intent_classifier,
245
+ self.mock_retriever,
246
+ )
247
+
248
+ # Should return technical issue message
249
+ self.assertIn("technical issue", answer)
250
+
251
+ # Verify error was logged
252
+ mock_log_timing.assert_called_once()
253
+ call_kwargs = mock_log_timing.call_args[1]
254
+ self.assertIsNotNone(call_kwargs.get("error_step"))
255
+
256
+ @patch("app.log_timing_data")
257
+ @patch("app.update_memory")
258
+ @patch("app.retrieve_memory")
259
+ @patch("app.create_session_config")
260
+ def test_get_context_and_answer_main_error(
261
+ self,
262
+ mock_session_config,
263
+ mock_retrieve_memory,
264
+ mock_update_memory,
265
+ mock_log_timing,
266
+ ):
267
+ """Test get_context_and_answer handles main pipeline errors"""
268
+ # Setup mocks
269
+ mock_session_config.return_value = {"session_id": self.session_id}
270
+ mock_retrieve_memory.side_effect = Exception("Memory error")
271
+
272
+ # Call function
273
+ answer = get_context_and_answer(
274
+ self.message,
275
+ self.history,
276
+ self.session_id,
277
+ self.mock_intent_classifier,
278
+ self.mock_retriever,
279
+ )
280
+
281
+ # Should return error message
282
+ self.assertIn("error", answer)
283
+
284
+ # Verify error was logged
285
+ mock_log_timing.assert_called_once()
286
+
287
+ @patch("app.generate_xeno_response")
288
+ @patch("app.process_context")
289
+ @patch("app.generate_embeddings")
290
+ @patch("app.log_timing_data")
291
+ @patch("app.log_response")
292
+ @patch("app.update_memory")
293
+ @patch("app.retrieve_memory")
294
+ @patch("app.create_session_config")
295
+ def test_get_context_and_answer_with_chat_history(
296
+ self,
297
+ mock_session_config,
298
+ mock_retrieve_memory,
299
+ mock_update_memory,
300
+ mock_log_response,
301
+ mock_log_timing,
302
+ mock_generate_embeddings,
303
+ mock_process_context,
304
+ mock_generate_response,
305
+ ):
306
+ """Test get_context_and_answer passes chat history to LLM"""
307
+ # Setup mocks
308
+ mock_session_config.return_value = {"session_id": self.session_id}
309
+ chat_history = [
310
+ {"role": "user", "content": "Previous question"},
311
+ {"role": "assistant", "content": "Previous answer"},
312
+ ]
313
+ mock_retrieve_memory.return_value = chat_history
314
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
315
+
316
+ # Mock retriever
317
+ mock_doc = Mock()
318
+ mock_doc.page_content = "Test content"
319
+ mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
320
+ self.mock_retriever.invoke.return_value = [mock_doc]
321
+
322
+ # Mock embeddings
323
+ mock_generate_embeddings.return_value = ([0.1, 0.2], [[0.9, 0.1]])
324
+
325
+ # Mock context processing
326
+ mock_process_context.return_value = ("Context", ["KB001"], [("Q", "A")])
327
+
328
+ # Mock LLM response
329
+ mock_generate_response.return_value = "Answer with context"
330
+
331
+ # Call function
332
+ answer = get_context_and_answer(
333
+ self.message,
334
+ self.history,
335
+ self.session_id,
336
+ self.mock_intent_classifier,
337
+ self.mock_retriever,
338
+ )
339
+
340
+ # Verify chat history was passed to LLM
341
+ mock_generate_response.assert_called_once()
342
+ call_args = mock_generate_response.call_args[0]
343
+ self.assertEqual(call_args[2], chat_history)
344
+
345
+ @patch("app.PipelineTimer")
346
+ @patch("app.generate_xeno_response")
347
+ @patch("app.process_context")
348
+ @patch("app.generate_embeddings")
349
+ @patch("app.log_timing_data")
350
+ @patch("app.log_response")
351
+ @patch("app.update_memory")
352
+ @patch("app.retrieve_memory")
353
+ @patch("app.create_session_config")
354
+ def test_get_context_and_answer_timing(
355
+ self,
356
+ mock_session_config,
357
+ mock_retrieve_memory,
358
+ mock_update_memory,
359
+ mock_log_response,
360
+ mock_log_timing,
361
+ mock_generate_embeddings,
362
+ mock_process_context,
363
+ mock_generate_response,
364
+ mock_timer_class,
365
+ ):
366
+ """Test get_context_and_answer uses PipelineTimer correctly"""
367
+ # Setup mocks
368
+ mock_timer = Mock()
369
+ mock_timer.time_step = MagicMock()
370
+ mock_timer.time_step.return_value.__enter__ = Mock()
371
+ mock_timer.time_step.return_value.__exit__ = Mock()
372
+ mock_timer.get_timing_summary.return_value = {"total": 1.5}
373
+ mock_timer_class.return_value = mock_timer
374
+
375
+ mock_session_config.return_value = {"session_id": self.session_id}
376
+ mock_retrieve_memory.return_value = []
377
+ self.mock_intent_classifier.classify_intent.return_value = ("query", None)
378
+
379
+ # Mock retriever
380
+ mock_doc = Mock()
381
+ mock_doc.page_content = "Test"
382
+ mock_doc.metadata = {"id": "KB001", "question": "Q", "content": "A"}
383
+ self.mock_retriever.invoke.return_value = [mock_doc]
384
+
385
+ # Mock embeddings
386
+ mock_generate_embeddings.return_value = ([0.1], [[0.9]])
387
+ mock_process_context.return_value = ("Context", ["KB001"], [("Q", "A")])
388
+ mock_generate_response.return_value = "Answer"
389
+
390
+ # Call function
391
+ get_context_and_answer(
392
+ self.message,
393
+ self.history,
394
+ self.session_id,
395
+ self.mock_intent_classifier,
396
+ self.mock_retriever,
397
+ )
398
+
399
+ # Verify timer was used
400
+ mock_timer.reset.assert_called_once()
401
+ mock_timer.get_timing_summary.assert_called()
402
+
403
+ # Verify timing was logged
404
+ mock_log_timing.assert_called_once()
405
+ call_args = mock_log_timing.call_args[0]
406
+ # Second positional argument is session_id, third is timing_summary
407
+ self.assertIn("total", call_args[2])
408
+
409
+
410
+ if __name__ == "__main__":
411
+ unittest.main()
tests/test_intent_classifier.py CHANGED
@@ -2,25 +2,27 @@
2
  Unit tests for intent_classifier module
3
  Tests the IntentClassifier class
4
  """
 
5
  import unittest
6
  from unittest.mock import Mock
 
7
  from src.intent_classifier import IntentClassifier
8
 
9
 
10
  class TestIntentClassifier(unittest.TestCase):
11
  """Test cases for IntentClassifier class"""
12
-
13
  def setUp(self):
14
  """Set up test fixtures"""
15
  self.classifier = IntentClassifier()
16
-
17
  def test_initialization(self):
18
  """Test classifier initialization"""
19
  self.assertIsNotNone(self.classifier.intent_patterns)
20
- self.assertIn('greeting', self.classifier.intent_patterns)
21
- self.assertIn('thanks', self.classifier.intent_patterns)
22
- self.assertIn('goodbye', self.classifier.intent_patterns)
23
-
24
  def test_classify_greeting(self):
25
  """Test classification of greeting messages"""
26
  test_cases = [
@@ -29,15 +31,15 @@ class TestIntentClassifier(unittest.TestCase):
29
  "Hey there",
30
  "good morning",
31
  "Good afternoon!",
32
- "how are you"
33
  ]
34
-
35
  for message in test_cases:
36
  intent, response = self.classifier.classify_intent(message)
37
- self.assertEqual(intent, 'greeting', f"Failed for message: {message}")
38
  self.assertIsInstance(response, str)
39
  self.assertGreater(len(response), 0)
40
-
41
  def test_classify_thanks(self):
42
  """Test classification of thank you messages"""
43
  test_cases = [
@@ -47,15 +49,15 @@ class TestIntentClassifier(unittest.TestCase):
47
  "thx",
48
  "I appreciate it",
49
  "thanks a lot",
50
- "thank you so much"
51
  ]
52
-
53
  for message in test_cases:
54
  intent, response = self.classifier.classify_intent(message)
55
- self.assertEqual(intent, 'thanks', f"Failed for message: {message}")
56
  self.assertIsInstance(response, str)
57
  self.assertGreater(len(response), 0)
58
-
59
  def test_classify_goodbye(self):
60
  """Test classification of goodbye messages"""
61
  test_cases = [
@@ -65,78 +67,82 @@ class TestIntentClassifier(unittest.TestCase):
65
  "farewell",
66
  "take care",
67
  "have a good day",
68
- "talk to you later"
69
  ]
70
-
71
  for message in test_cases:
72
  intent, response = self.classifier.classify_intent(message)
73
- self.assertEqual(intent, 'goodbye', f"Failed for message: {message}")
74
  self.assertIsInstance(response, str)
75
  self.assertGreater(len(response), 0)
76
-
77
  def test_classify_query(self):
78
  """Test classification of query messages"""
79
  test_cases = [
80
  "How do I open an account?",
81
  "What are the transaction fees?",
82
  "Can you help me with my balance?",
83
- "Tell me about XENO services"
84
  ]
85
-
86
  for message in test_cases:
87
  intent, response = self.classifier.classify_intent(message)
88
- self.assertEqual(intent, 'query', f"Failed for message: {message}")
89
- self.assertEqual(response, '')
90
-
91
  def test_case_insensitivity(self):
92
  """Test that classification is case insensitive"""
93
  messages = [
94
- ("HI", 'greeting'),
95
- ("THANK YOU", 'thanks'),
96
- ("BYE", 'goodbye'),
97
- ("Hi There", 'greeting')
98
  ]
99
-
100
  for message, expected_intent in messages:
101
  intent, _ = self.classifier.classify_intent(message)
102
  self.assertEqual(intent, expected_intent)
103
-
104
  def test_with_timer(self):
105
  """Test classification with timer object"""
106
  mock_timer = Mock()
107
  mock_timer.time_step = Mock()
108
  mock_timer.time_step.return_value.__enter__ = Mock()
109
  mock_timer.time_step.return_value.__exit__ = Mock()
110
-
111
  intent, response = self.classifier.classify_intent("hello", timer=mock_timer)
112
-
113
- self.assertEqual(intent, 'greeting')
114
  mock_timer.time_step.assert_called_once_with("intent_classification")
115
-
116
  def test_is_simple_intent(self):
117
  """Test is_simple_intent method"""
118
- self.assertTrue(self.classifier.is_simple_intent('greeting'))
119
- self.assertTrue(self.classifier.is_simple_intent('thanks'))
120
- self.assertFalse(self.classifier.is_simple_intent('goodbye'))
121
- self.assertFalse(self.classifier.is_simple_intent('query'))
122
-
123
  def test_add_intent(self):
124
  """Test adding a new intent"""
125
- patterns = [r'\b(test|testing)\b']
126
  responses = ["This is a test response"]
127
-
128
- self.classifier.add_intent('test_intent', patterns, responses)
129
-
130
  # Verify intent was added
131
- self.assertIn('test_intent', self.classifier.intent_patterns)
132
- self.assertEqual(self.classifier.intent_patterns['test_intent']['patterns'], patterns)
133
- self.assertEqual(self.classifier.intent_patterns['test_intent']['responses'], responses)
134
-
 
 
 
 
135
  # Test classification with new intent
136
  intent, response = self.classifier.classify_intent("testing")
137
- self.assertEqual(intent, 'test_intent')
138
  self.assertEqual(response, "This is a test response")
139
-
140
  def test_response_variety(self):
141
  """Test that responses vary (random selection)"""
142
  # Multiple calls might return different responses
@@ -144,26 +150,26 @@ class TestIntentClassifier(unittest.TestCase):
144
  for _ in range(20):
145
  _, response = self.classifier.classify_intent("hello")
146
  responses.add(response)
147
-
148
  # Should have at least 1 response (could be more if random varies)
149
  self.assertGreater(len(responses), 0)
150
-
151
  def test_empty_message(self):
152
  """Test classification of empty or whitespace messages"""
153
  test_cases = ["", " ", "\n", "\t"]
154
-
155
  for message in test_cases:
156
  intent, response = self.classifier.classify_intent(message)
157
- self.assertEqual(intent, 'query')
158
- self.assertEqual(response, '')
159
-
160
  def test_mixed_intent_message(self):
161
  """Test messages that might match multiple patterns"""
162
  # "hi thank you" should match greeting (first match wins)
163
  intent, response = self.classifier.classify_intent("hi thank you")
164
  # Should match the first pattern it encounters
165
- self.assertIn(intent, ['greeting', 'thanks'])
166
 
167
 
168
- if __name__ == '__main__':
169
  unittest.main()
 
2
  Unit tests for intent_classifier module
3
  Tests the IntentClassifier class
4
  """
5
+
6
  import unittest
7
  from unittest.mock import Mock
8
+
9
  from src.intent_classifier import IntentClassifier
10
 
11
 
12
  class TestIntentClassifier(unittest.TestCase):
13
  """Test cases for IntentClassifier class"""
14
+
15
  def setUp(self):
16
  """Set up test fixtures"""
17
  self.classifier = IntentClassifier()
18
+
19
  def test_initialization(self):
20
  """Test classifier initialization"""
21
  self.assertIsNotNone(self.classifier.intent_patterns)
22
+ self.assertIn("greeting", self.classifier.intent_patterns)
23
+ self.assertIn("thanks", self.classifier.intent_patterns)
24
+ self.assertIn("goodbye", self.classifier.intent_patterns)
25
+
26
  def test_classify_greeting(self):
27
  """Test classification of greeting messages"""
28
  test_cases = [
 
31
  "Hey there",
32
  "good morning",
33
  "Good afternoon!",
34
+ "how are you",
35
  ]
36
+
37
  for message in test_cases:
38
  intent, response = self.classifier.classify_intent(message)
39
+ self.assertEqual(intent, "greeting", f"Failed for message: {message}")
40
  self.assertIsInstance(response, str)
41
  self.assertGreater(len(response), 0)
42
+
43
  def test_classify_thanks(self):
44
  """Test classification of thank you messages"""
45
  test_cases = [
 
49
  "thx",
50
  "I appreciate it",
51
  "thanks a lot",
52
+ "thank you so much",
53
  ]
54
+
55
  for message in test_cases:
56
  intent, response = self.classifier.classify_intent(message)
57
+ self.assertEqual(intent, "thanks", f"Failed for message: {message}")
58
  self.assertIsInstance(response, str)
59
  self.assertGreater(len(response), 0)
60
+
61
  def test_classify_goodbye(self):
62
  """Test classification of goodbye messages"""
63
  test_cases = [
 
67
  "farewell",
68
  "take care",
69
  "have a good day",
70
+ "talk to you later",
71
  ]
72
+
73
  for message in test_cases:
74
  intent, response = self.classifier.classify_intent(message)
75
+ self.assertEqual(intent, "goodbye", f"Failed for message: {message}")
76
  self.assertIsInstance(response, str)
77
  self.assertGreater(len(response), 0)
78
+
79
  def test_classify_query(self):
80
  """Test classification of query messages"""
81
  test_cases = [
82
  "How do I open an account?",
83
  "What are the transaction fees?",
84
  "Can you help me with my balance?",
85
+ "Tell me about XENO services",
86
  ]
87
+
88
  for message in test_cases:
89
  intent, response = self.classifier.classify_intent(message)
90
+ self.assertEqual(intent, "query", f"Failed for message: {message}")
91
+ self.assertEqual(response, "")
92
+
93
  def test_case_insensitivity(self):
94
  """Test that classification is case insensitive"""
95
  messages = [
96
+ ("HI", "greeting"),
97
+ ("THANK YOU", "thanks"),
98
+ ("BYE", "goodbye"),
99
+ ("Hi There", "greeting"),
100
  ]
101
+
102
  for message, expected_intent in messages:
103
  intent, _ = self.classifier.classify_intent(message)
104
  self.assertEqual(intent, expected_intent)
105
+
106
  def test_with_timer(self):
107
  """Test classification with timer object"""
108
  mock_timer = Mock()
109
  mock_timer.time_step = Mock()
110
  mock_timer.time_step.return_value.__enter__ = Mock()
111
  mock_timer.time_step.return_value.__exit__ = Mock()
112
+
113
  intent, response = self.classifier.classify_intent("hello", timer=mock_timer)
114
+
115
+ self.assertEqual(intent, "greeting")
116
  mock_timer.time_step.assert_called_once_with("intent_classification")
117
+
118
  def test_is_simple_intent(self):
119
  """Test is_simple_intent method"""
120
+ self.assertTrue(self.classifier.is_simple_intent("greeting"))
121
+ self.assertTrue(self.classifier.is_simple_intent("thanks"))
122
+ self.assertFalse(self.classifier.is_simple_intent("goodbye"))
123
+ self.assertFalse(self.classifier.is_simple_intent("query"))
124
+
125
  def test_add_intent(self):
126
  """Test adding a new intent"""
127
+ patterns = [r"\b(test|testing)\b"]
128
  responses = ["This is a test response"]
129
+
130
+ self.classifier.add_intent("test_intent", patterns, responses)
131
+
132
  # Verify intent was added
133
+ self.assertIn("test_intent", self.classifier.intent_patterns)
134
+ self.assertEqual(
135
+ self.classifier.intent_patterns["test_intent"]["patterns"], patterns
136
+ )
137
+ self.assertEqual(
138
+ self.classifier.intent_patterns["test_intent"]["responses"], responses
139
+ )
140
+
141
  # Test classification with new intent
142
  intent, response = self.classifier.classify_intent("testing")
143
+ self.assertEqual(intent, "test_intent")
144
  self.assertEqual(response, "This is a test response")
145
+
146
  def test_response_variety(self):
147
  """Test that responses vary (random selection)"""
148
  # Multiple calls might return different responses
 
150
  for _ in range(20):
151
  _, response = self.classifier.classify_intent("hello")
152
  responses.add(response)
153
+
154
  # Should have at least 1 response (could be more if random varies)
155
  self.assertGreater(len(responses), 0)
156
+
157
  def test_empty_message(self):
158
  """Test classification of empty or whitespace messages"""
159
  test_cases = ["", " ", "\n", "\t"]
160
+
161
  for message in test_cases:
162
  intent, response = self.classifier.classify_intent(message)
163
+ self.assertEqual(intent, "query")
164
+ self.assertEqual(response, "")
165
+
166
  def test_mixed_intent_message(self):
167
  """Test messages that might match multiple patterns"""
168
  # "hi thank you" should match greeting (first match wins)
169
  intent, response = self.classifier.classify_intent("hi thank you")
170
  # Should match the first pattern it encounters
171
+ self.assertIn(intent, ["greeting", "thanks"])
172
 
173
 
174
+ if __name__ == "__main__":
175
  unittest.main()
tests/test_interface.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for interface module
3
+ Tests Gradio interface functionality
4
+ """
5
+
6
+ import unittest
7
+ import uuid
8
+ from unittest.mock import MagicMock, Mock, patch
9
+
10
+ from src.interface import create_interface, respond
11
+
12
+
13
+ class TestInterface(unittest.TestCase):
14
+ """Test cases for interface module"""
15
+
16
+ def setUp(self):
17
+ """Set up test fixtures"""
18
+ self.message = "How do I create an account?"
19
+ self.history = [["Previous question", "Previous answer"]]
20
+ self.session_id = str(uuid.uuid4())
21
+ self.mock_intent_classifier = Mock()
22
+ self.mock_retriever = Mock()
23
+
24
+ @patch("app.get_context_and_answer")
25
+ def test_respond_with_session_id(self, mock_get_answer):
26
+ """Test respond function with existing session ID"""
27
+ mock_get_answer.return_value = "You can create an account by visiting our website."
28
+
29
+ result_msg, result_history = respond(
30
+ self.message,
31
+ self.history.copy(),
32
+ self.session_id,
33
+ self.mock_intent_classifier,
34
+ self.mock_retriever,
35
+ )
36
+
37
+ # Verify get_context_and_answer was called
38
+ mock_get_answer.assert_called_once()
39
+ call_args = mock_get_answer.call_args[0]
40
+ self.assertEqual(call_args[0], self.message)
41
+ self.assertEqual(call_args[2], self.session_id)
42
+
43
+ # Check return values
44
+ self.assertEqual(result_msg, "")
45
+ self.assertEqual(len(result_history), 2)
46
+ self.assertEqual(result_history[-1][0], self.message)
47
+ self.assertEqual(
48
+ result_history[-1][1],
49
+ "You can create an account by visiting our website.",
50
+ )
51
+
52
+ @patch("app.get_context_and_answer")
53
+ def test_respond_without_session_id(self, mock_get_answer):
54
+ """Test respond function generates session ID when none provided"""
55
+ mock_get_answer.return_value = "Response"
56
+
57
+ result_msg, result_history = respond(
58
+ self.message,
59
+ [],
60
+ None,
61
+ self.mock_intent_classifier,
62
+ self.mock_retriever,
63
+ )
64
+
65
+ # Should have called with a generated session ID
66
+ self.assertEqual(mock_get_answer.call_count, 1)
67
+ call_args = mock_get_answer.call_args[0]
68
+ generated_session_id = call_args[2]
69
+
70
+ # Verify it's a valid UUID
71
+ try:
72
+ uuid.UUID(generated_session_id)
73
+ valid_uuid = True
74
+ except ValueError:
75
+ valid_uuid = False
76
+
77
+ self.assertTrue(valid_uuid)
78
+
79
+ # Check return values
80
+ self.assertEqual(result_msg, "")
81
+ self.assertEqual(len(result_history), 1)
82
+
83
+ @patch("app.get_context_and_answer")
84
+ def test_respond_with_empty_history(self, mock_get_answer):
85
+ """Test respond function with empty history"""
86
+ mock_get_answer.return_value = "Test response"
87
+
88
+ result_msg, result_history = respond(
89
+ "Test question",
90
+ [],
91
+ self.session_id,
92
+ self.mock_intent_classifier,
93
+ self.mock_retriever,
94
+ )
95
+
96
+ # History should have one entry
97
+ self.assertEqual(len(result_history), 1)
98
+ self.assertEqual(result_history[0][0], "Test question")
99
+ self.assertEqual(result_history[0][1], "Test response")
100
+
101
+ @patch("app.get_context_and_answer")
102
+ def test_respond_preserves_existing_history(self, mock_get_answer):
103
+ """Test respond function preserves existing chat history"""
104
+ mock_get_answer.return_value = "New response"
105
+
106
+ initial_history = [
107
+ ["Question 1", "Answer 1"],
108
+ ["Question 2", "Answer 2"],
109
+ ]
110
+
111
+ result_msg, result_history = respond(
112
+ "Question 3",
113
+ initial_history.copy(),
114
+ self.session_id,
115
+ self.mock_intent_classifier,
116
+ self.mock_retriever,
117
+ )
118
+
119
+ # Should have 3 entries now
120
+ self.assertEqual(len(result_history), 3)
121
+ self.assertEqual(result_history[0][0], "Question 1")
122
+ self.assertEqual(result_history[1][0], "Question 2")
123
+ self.assertEqual(result_history[2][0], "Question 3")
124
+
125
+ def test_create_interface_returns_blocks(self):
126
+ """Test create_interface returns Gradio Blocks interface"""
127
+ result = create_interface(self.mock_intent_classifier, self.mock_retriever)
128
+
129
+ # Should return a Gradio Blocks object
130
+ import gradio as gr
131
+ self.assertIsInstance(result, gr.Blocks)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ unittest.main()
tests/test_knowledge_base.py CHANGED
@@ -2,22 +2,22 @@
2
  Unit tests for knowledge_base module
3
  Tests knowledge base loading and preparation
4
  """
5
- import unittest
6
- import pandas as pd
7
  import json
8
- import tempfile
9
  import os
10
- from unittest.mock import patch, Mock
11
- from src.knowledge_base import (
12
- load_knowledge_base,
13
- prepare_documents,
14
- get_knowledge_base_data
15
- )
 
 
16
 
17
 
18
  class TestKnowledgeBase(unittest.TestCase):
19
  """Test cases for knowledge_base module"""
20
-
21
  def setUp(self):
22
  """Set up test fixtures"""
23
  # Create sample knowledge base data
@@ -29,7 +29,7 @@ class TestKnowledgeBase(unittest.TestCase):
29
  "Section": "Account Management",
30
  "Source": "Website",
31
  "Owner": "Support Team",
32
- "Tag": "account"
33
  },
34
  {
35
  "ID": "KB002",
@@ -38,35 +38,33 @@ class TestKnowledgeBase(unittest.TestCase):
38
  "Section": "Fees",
39
  "Source": "Documentation",
40
  "Owner": "Finance Team",
41
- "Tag": "fees"
42
- }
43
  ]
44
-
45
  # Create temporary JSON file
46
  self.temp_file = tempfile.NamedTemporaryFile(
47
- mode='w',
48
- delete=False,
49
- suffix='.json'
50
  )
51
  json.dump(self.sample_data, self.temp_file)
52
  self.temp_file.close()
53
-
54
  def tearDown(self):
55
  """Clean up test fixtures"""
56
  if os.path.exists(self.temp_file.name):
57
  os.unlink(self.temp_file.name)
58
-
59
  def test_load_knowledge_base(self):
60
  """Test loading knowledge base from JSON file"""
61
  df = load_knowledge_base(self.temp_file.name)
62
-
63
  # Check DataFrame structure
64
  self.assertIsInstance(df, pd.DataFrame)
65
  self.assertEqual(len(df), 2)
66
- self.assertIn('ID', df.columns)
67
- self.assertIn('Question', df.columns)
68
- self.assertIn('Content', df.columns)
69
-
70
  def test_load_knowledge_base_drops_null_content(self):
71
  """Test that rows with null Content are dropped"""
72
  data_with_null = self.sample_data + [
@@ -74,110 +72,124 @@ class TestKnowledgeBase(unittest.TestCase):
74
  "ID": "KB003",
75
  "Question": "Test question?",
76
  "Content": None,
77
- "Section": "Test"
78
  }
79
  ]
80
-
81
  temp_file_null = tempfile.NamedTemporaryFile(
82
- mode='w',
83
- delete=False,
84
- suffix='.json'
85
  )
86
  json.dump(data_with_null, temp_file_null)
87
  temp_file_null.close()
88
-
89
  try:
90
  df = load_knowledge_base(temp_file_null.name)
91
  # Should only have 2 rows (null Content row dropped)
92
  self.assertEqual(len(df), 2)
93
  finally:
94
  os.unlink(temp_file_null.name)
95
-
96
  def test_prepare_documents(self):
97
  """Test preparing documents for vector store"""
98
  documents, metadatas, ids = prepare_documents(self.sample_data)
99
-
100
  # Check lengths match
101
  self.assertEqual(len(documents), 2)
102
  self.assertEqual(len(metadatas), 2)
103
  self.assertEqual(len(ids), 2)
104
-
105
  # Check document format
106
  self.assertIn("Question:", documents[0])
107
  self.assertIn("Answer:", documents[0])
108
  self.assertIn("How do I create an account?", documents[0])
109
-
110
  # Check metadata structure
111
- self.assertEqual(metadatas[0]['id'], 'KB001')
112
- self.assertEqual(metadatas[0]['question'], 'How do I create an account?')
113
- self.assertEqual(metadatas[0]['section'], 'Account Management')
114
-
115
  # Check IDs
116
- self.assertEqual(ids[0], 'KB001')
117
- self.assertEqual(ids[1], 'KB002')
118
-
119
  def test_prepare_documents_with_missing_fields(self):
120
  """Test preparing documents with missing optional fields"""
121
  data_minimal = [
122
- {
123
- "ID": "KB001",
124
- "Question": "Test question?",
125
- "Content": "Test answer."
126
- }
127
  ]
128
-
129
  documents, metadatas, ids = prepare_documents(data_minimal)
130
-
131
  # Should still work with defaults
132
  self.assertEqual(len(documents), 1)
133
- self.assertEqual(metadatas[0]['section'], '')
134
- self.assertEqual(metadatas[0]['source'], '')
135
- self.assertEqual(metadatas[0]['owner'], '')
136
- self.assertEqual(metadatas[0]['tag'], '')
137
-
138
- @patch('src.knowledge_base.load_knowledge_base')
139
  def test_get_knowledge_base_data(self, mock_load):
140
  """Test get_knowledge_base_data function"""
141
  # Mock the load_knowledge_base function
142
  mock_df = pd.DataFrame(self.sample_data)
143
  mock_load.return_value = mock_df
144
-
145
  documents, metadatas, ids = get_knowledge_base_data()
146
-
147
  # Verify load was called
148
  mock_load.assert_called_once()
149
-
150
  # Verify output
151
  self.assertEqual(len(documents), 2)
152
  self.assertEqual(len(metadatas), 2)
153
  self.assertEqual(len(ids), 2)
154
-
155
  def test_document_text_format(self):
156
  """Test that document text is properly formatted"""
157
  documents, _, _ = prepare_documents(self.sample_data)
158
-
159
  # Check first document format
160
  expected_format = "Question: How do I create an account?\nAnswer: You can create an account by visiting our website."
161
  self.assertEqual(documents[0], expected_format)
162
-
163
  def test_empty_knowledge_base(self):
164
  """Test handling of empty knowledge base"""
165
  empty_data = []
166
  documents, metadatas, ids = prepare_documents(empty_data)
167
-
168
  self.assertEqual(len(documents), 0)
169
  self.assertEqual(len(metadatas), 0)
170
  self.assertEqual(len(ids), 0)
171
-
172
  def test_metadata_completeness(self):
173
  """Test that all metadata fields are present"""
174
  _, metadatas, _ = prepare_documents(self.sample_data)
175
-
176
- required_fields = ['question', 'content', 'section', 'source', 'owner', 'tag', 'id']
 
 
 
 
 
 
 
 
177
  for metadata in metadatas:
178
  for field in required_fields:
179
  self.assertIn(field, metadata)
180
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- if __name__ == '__main__':
183
  unittest.main()
 
2
  Unit tests for knowledge_base module
3
  Tests knowledge base loading and preparation
4
  """
5
+
 
6
  import json
 
7
  import os
8
+ import tempfile
9
+ import unittest
10
+ from unittest.mock import patch
11
+
12
+ import pandas as pd
13
+
14
+ from src.knowledge_base import (get_knowledge_base_data, load_knowledge_base,
15
+ prepare_documents)
16
 
17
 
18
  class TestKnowledgeBase(unittest.TestCase):
19
  """Test cases for knowledge_base module"""
20
+
21
  def setUp(self):
22
  """Set up test fixtures"""
23
  # Create sample knowledge base data
 
29
  "Section": "Account Management",
30
  "Source": "Website",
31
  "Owner": "Support Team",
32
+ "Tag": "account",
33
  },
34
  {
35
  "ID": "KB002",
 
38
  "Section": "Fees",
39
  "Source": "Documentation",
40
  "Owner": "Finance Team",
41
+ "Tag": "fees",
42
+ },
43
  ]
44
+
45
  # Create temporary JSON file
46
  self.temp_file = tempfile.NamedTemporaryFile(
47
+ mode="w", delete=False, suffix=".json"
 
 
48
  )
49
  json.dump(self.sample_data, self.temp_file)
50
  self.temp_file.close()
51
+
52
  def tearDown(self):
53
  """Clean up test fixtures"""
54
  if os.path.exists(self.temp_file.name):
55
  os.unlink(self.temp_file.name)
56
+
57
  def test_load_knowledge_base(self):
58
  """Test loading knowledge base from JSON file"""
59
  df = load_knowledge_base(self.temp_file.name)
60
+
61
  # Check DataFrame structure
62
  self.assertIsInstance(df, pd.DataFrame)
63
  self.assertEqual(len(df), 2)
64
+ self.assertIn("ID", df.columns)
65
+ self.assertIn("Question", df.columns)
66
+ self.assertIn("Content", df.columns)
67
+
68
  def test_load_knowledge_base_drops_null_content(self):
69
  """Test that rows with null Content are dropped"""
70
  data_with_null = self.sample_data + [
 
72
  "ID": "KB003",
73
  "Question": "Test question?",
74
  "Content": None,
75
+ "Section": "Test",
76
  }
77
  ]
78
+
79
  temp_file_null = tempfile.NamedTemporaryFile(
80
+ mode="w", delete=False, suffix=".json"
 
 
81
  )
82
  json.dump(data_with_null, temp_file_null)
83
  temp_file_null.close()
84
+
85
  try:
86
  df = load_knowledge_base(temp_file_null.name)
87
  # Should only have 2 rows (null Content row dropped)
88
  self.assertEqual(len(df), 2)
89
  finally:
90
  os.unlink(temp_file_null.name)
91
+
92
  def test_prepare_documents(self):
93
  """Test preparing documents for vector store"""
94
  documents, metadatas, ids = prepare_documents(self.sample_data)
95
+
96
  # Check lengths match
97
  self.assertEqual(len(documents), 2)
98
  self.assertEqual(len(metadatas), 2)
99
  self.assertEqual(len(ids), 2)
100
+
101
  # Check document format
102
  self.assertIn("Question:", documents[0])
103
  self.assertIn("Answer:", documents[0])
104
  self.assertIn("How do I create an account?", documents[0])
105
+
106
  # Check metadata structure
107
+ self.assertEqual(metadatas[0]["id"], "KB001")
108
+ self.assertEqual(metadatas[0]["question"], "How do I create an account?")
109
+ self.assertEqual(metadatas[0]["section"], "Account Management")
110
+
111
  # Check IDs
112
+ self.assertEqual(ids[0], "KB001")
113
+ self.assertEqual(ids[1], "KB002")
114
+
115
  def test_prepare_documents_with_missing_fields(self):
116
  """Test preparing documents with missing optional fields"""
117
  data_minimal = [
118
+ {"ID": "KB001", "Question": "Test question?", "Content": "Test answer."}
 
 
 
 
119
  ]
120
+
121
  documents, metadatas, ids = prepare_documents(data_minimal)
122
+
123
  # Should still work with defaults
124
  self.assertEqual(len(documents), 1)
125
+ self.assertEqual(metadatas[0]["section"], "")
126
+ self.assertEqual(metadatas[0]["source"], "")
127
+ self.assertEqual(metadatas[0]["owner"], "")
128
+ self.assertEqual(metadatas[0]["tag"], "")
129
+
130
+ @patch("src.knowledge_base.load_knowledge_base")
131
  def test_get_knowledge_base_data(self, mock_load):
132
  """Test get_knowledge_base_data function"""
133
  # Mock the load_knowledge_base function
134
  mock_df = pd.DataFrame(self.sample_data)
135
  mock_load.return_value = mock_df
136
+
137
  documents, metadatas, ids = get_knowledge_base_data()
138
+
139
  # Verify load was called
140
  mock_load.assert_called_once()
141
+
142
  # Verify output
143
  self.assertEqual(len(documents), 2)
144
  self.assertEqual(len(metadatas), 2)
145
  self.assertEqual(len(ids), 2)
146
+
147
  def test_document_text_format(self):
148
  """Test that document text is properly formatted"""
149
  documents, _, _ = prepare_documents(self.sample_data)
150
+
151
  # Check first document format
152
  expected_format = "Question: How do I create an account?\nAnswer: You can create an account by visiting our website."
153
  self.assertEqual(documents[0], expected_format)
154
+
155
  def test_empty_knowledge_base(self):
156
  """Test handling of empty knowledge base"""
157
  empty_data = []
158
  documents, metadatas, ids = prepare_documents(empty_data)
159
+
160
  self.assertEqual(len(documents), 0)
161
  self.assertEqual(len(metadatas), 0)
162
  self.assertEqual(len(ids), 0)
163
+
164
  def test_metadata_completeness(self):
165
  """Test that all metadata fields are present"""
166
  _, metadatas, _ = prepare_documents(self.sample_data)
167
+
168
+ required_fields = [
169
+ "question",
170
+ "content",
171
+ "section",
172
+ "source",
173
+ "owner",
174
+ "tag",
175
+ "id",
176
+ ]
177
  for metadata in metadatas:
178
  for field in required_fields:
179
  self.assertIn(field, metadata)
180
 
181
+ @patch("src.knowledge_base.load_knowledge_base")
182
+ def test_get_knowledge_base_data_with_exception(self, mock_load):
183
+ """Test get_knowledge_base_data handles exceptions"""
184
+ # Make load_knowledge_base raise an exception
185
+ mock_load.side_effect = Exception("File not found")
186
+
187
+ # Should raise the exception
188
+ with self.assertRaises(Exception) as context:
189
+ get_knowledge_base_data()
190
+
191
+ self.assertIn("File not found", str(context.exception))
192
+
193
 
194
+ if __name__ == "__main__":
195
  unittest.main()
tests/test_logger.py CHANGED
@@ -2,19 +2,16 @@
2
  Unit tests for logger module
3
  Tests Google Sheets logging functionality
4
  """
 
5
  import unittest
6
- from datetime import datetime
7
- from unittest.mock import patch, Mock, MagicMock
8
- from src.logger import (
9
- log_response,
10
- log_timing_data,
11
- _log_response_impl
12
- )
13
 
14
 
15
  class TestLogger(unittest.TestCase):
16
  """Test cases for logger module"""
17
-
18
  def setUp(self):
19
  """Set up test fixtures"""
20
  self.question = "How do I create an account?"
@@ -22,11 +19,11 @@ class TestLogger(unittest.TestCase):
22
  self.source_ids = "KB001, KB002"
23
  self.knowledge_pairs = [
24
  ("Question 1?", "Answer 1."),
25
- ("Question 2?", "Answer 2.")
26
  ]
27
  self.session_id = "test_session_123"
28
-
29
- @patch('src.logger.response_sheet')
30
  def test_log_response_impl(self, mock_sheet):
31
  """Test internal response logging implementation"""
32
  _log_response_impl(
@@ -34,18 +31,20 @@ class TestLogger(unittest.TestCase):
34
  self.answer,
35
  self.source_ids,
36
  self.knowledge_pairs,
37
- self.session_id
38
  )
39
-
40
  # Verify append_row was called
41
  mock_sheet.append_row.assert_called_once()
42
-
43
  # Check the row data
44
  call_args = mock_sheet.append_row.call_args
45
  row = call_args[0][0]
46
-
47
  # Verify row structure
48
- self.assertEqual(len(row), 9) # timestamp, session_id, question, answer, source_ids, 4 knowledge fields
 
 
49
  self.assertEqual(row[1], self.session_id)
50
  self.assertEqual(row[2], self.question)
51
  self.assertEqual(row[3], self.answer)
@@ -54,219 +53,198 @@ class TestLogger(unittest.TestCase):
54
  self.assertEqual(row[6], "Answer 1.")
55
  self.assertEqual(row[7], "Question 2?")
56
  self.assertEqual(row[8], "Answer 2.")
57
-
58
- @patch('src.logger.response_sheet')
59
  def test_log_response_with_timer(self, mock_sheet):
60
  """Test log_response with timer"""
61
  mock_timer = Mock()
62
  mock_timer.time_step = MagicMock()
63
  mock_timer.time_step.return_value.__enter__ = Mock()
64
  mock_timer.time_step.return_value.__exit__ = Mock()
65
-
66
  log_response(
67
  self.question,
68
  self.answer,
69
  self.source_ids,
70
  self.knowledge_pairs,
71
  self.session_id,
72
- timer=mock_timer
73
  )
74
-
75
  # Verify timer was used
76
  mock_timer.time_step.assert_called_once_with("response_logging")
77
-
78
- @patch('src.logger.response_sheet')
79
  def test_log_response_empty_knowledge_pairs(self, mock_sheet):
80
  """Test logging with empty knowledge pairs"""
81
  _log_response_impl(
82
- self.question,
83
- self.answer,
84
- self.source_ids,
85
- [],
86
- self.session_id
87
  )
88
-
89
  # Should still work
90
  mock_sheet.append_row.assert_called_once()
91
-
92
  # Check that N/A is used for missing pairs
93
  row = mock_sheet.append_row.call_args[0][0]
94
  self.assertEqual(row[5], "N/A")
95
  self.assertEqual(row[6], "N/A")
96
-
97
- @patch('src.logger.response_sheet')
98
  def test_log_response_single_knowledge_pair(self, mock_sheet):
99
  """Test logging with single knowledge pair"""
100
  single_pair = [("Single question?", "Single answer.")]
101
-
102
  _log_response_impl(
103
- self.question,
104
- self.answer,
105
- self.source_ids,
106
- single_pair,
107
- self.session_id
108
  )
109
-
110
  row = mock_sheet.append_row.call_args[0][0]
111
-
112
  # First pair should be present
113
  self.assertEqual(row[5], "Single question?")
114
  self.assertEqual(row[6], "Single answer.")
115
-
116
  # Second pair should be N/A
117
  self.assertEqual(row[7], "N/A")
118
  self.assertEqual(row[8], "N/A")
119
-
120
- @patch('src.logger.response_sheet')
121
- @patch('builtins.open', create=True)
122
  def test_log_response_fallback_on_error(self, mock_open, mock_sheet):
123
  """Test fallback to file logging on error"""
124
  # Make append_row raise an exception
125
  mock_sheet.append_row.side_effect = Exception("Connection error")
126
-
127
  # Mock file operations
128
  mock_file = MagicMock()
129
  mock_open.return_value.__enter__.return_value = mock_file
130
-
131
  # Should not raise exception
132
  _log_response_impl(
133
  self.question,
134
  self.answer,
135
  self.source_ids,
136
  self.knowledge_pairs,
137
- self.session_id
138
  )
139
-
140
  # Verify fallback file was opened
141
  mock_open.assert_called_once_with("/tmp/response_log.txt", "a")
142
  mock_file.write.assert_called_once()
143
-
144
- @patch('src.logger.timing_sheet')
145
  def test_log_timing_data(self, mock_sheet):
146
  """Test timing data logging"""
147
  timing_summary = {
148
- 'total_time_ms': 1500,
149
- 'step_times': {
150
- 'intent_classification': 50,
151
- 'memory_retrieval': 100,
152
- 'rag_retrieval': 200,
153
- 'embedding_generation': 300,
154
- 'similarity_calculation': 150,
155
- 'context_processing': 100,
156
- 'llm_generation': 500,
157
- 'memory_update': 50,
158
- 'response_logging': 50
159
- }
160
  }
161
-
162
  log_timing_data(
163
  self.question,
164
  self.session_id,
165
  timing_summary,
166
  error_step=None,
167
- notes="Test note"
168
  )
169
-
170
  # Verify append_row was called
171
  mock_sheet.append_row.assert_called_once()
172
-
173
  # Check row structure
174
  row = mock_sheet.append_row.call_args[0][0]
175
-
176
  # Should have 15 fields
177
  self.assertEqual(len(row), 15)
178
  self.assertEqual(row[1], self.session_id)
179
  self.assertEqual(row[3], 1500) # total_time_ms
180
- self.assertEqual(row[4], 50) # intent_classification
181
- self.assertEqual(row[5], 100) # memory_retrieval
182
  self.assertEqual(row[14], "Test note") # notes
183
-
184
- @patch('src.logger.timing_sheet')
185
  def test_log_timing_data_with_error(self, mock_sheet):
186
  """Test timing data logging with error"""
187
  timing_summary = {
188
- 'total_time_ms': 500,
189
- 'step_times': {
190
- 'intent_classification': 50
191
- }
192
  }
193
-
194
  log_timing_data(
195
  self.question,
196
  self.session_id,
197
  timing_summary,
198
  error_step="rag_retrieval",
199
- notes="Error occurred"
200
  )
201
-
202
  row = mock_sheet.append_row.call_args[0][0]
203
-
204
  # Check error_step is logged
205
  self.assertEqual(row[13], "rag_retrieval")
206
  self.assertEqual(row[14], "Error occurred")
207
-
208
- @patch('src.logger.timing_sheet')
209
  def test_log_timing_data_missing_steps(self, mock_sheet):
210
  """Test timing data with missing step times"""
211
  timing_summary = {
212
- 'total_time_ms': 100,
213
- 'step_times': {
214
- 'intent_classification': 100
215
  # Other steps missing
216
- }
217
  }
218
-
219
- log_timing_data(
220
- self.question,
221
- self.session_id,
222
- timing_summary
223
- )
224
-
225
  row = mock_sheet.append_row.call_args[0][0]
226
-
227
  # Missing steps should default to 0
228
  self.assertEqual(row[5], 0) # memory_retrieval
229
  self.assertEqual(row[6], 0) # rag_retrieval
230
-
231
- @patch('src.logger.timing_sheet')
232
  def test_log_timing_data_long_question(self, mock_sheet):
233
  """Test timing data logging with long question (truncation)"""
234
  long_question = "A" * 150 # 150 characters
235
-
236
- timing_summary = {
237
- 'total_time_ms': 100,
238
- 'step_times': {}
239
- }
240
-
241
- log_timing_data(
242
- long_question,
243
- self.session_id,
244
- timing_summary
245
- )
246
-
247
  row = mock_sheet.append_row.call_args[0][0]
248
-
249
  # Question should be truncated to 103 chars (100 + "...")
250
  self.assertEqual(len(row[2]), 103)
251
  self.assertTrue(row[2].endswith("..."))
252
-
253
- @patch('src.logger.timing_sheet')
254
- @patch('builtins.open', create=True)
255
  def test_log_timing_data_fallback_on_error(self, mock_open, mock_sheet):
256
  """Test fallback to file logging for timing data on error"""
257
  mock_sheet.append_row.side_effect = Exception("Connection error")
258
-
259
  mock_file = MagicMock()
260
  mock_open.return_value.__enter__.return_value = mock_file
261
-
262
- timing_summary = {'total_time_ms': 100, 'step_times': {}}
263
-
264
  log_timing_data(self.question, self.session_id, timing_summary)
265
-
266
  # Verify fallback file was opened
267
  mock_open.assert_called_once_with("/tmp/timing_log.txt", "a")
268
  mock_file.write.assert_called_once()
269
 
270
 
271
- if __name__ == '__main__':
272
  unittest.main()
 
2
  Unit tests for logger module
3
  Tests Google Sheets logging functionality
4
  """
5
+
6
  import unittest
7
+ from unittest.mock import MagicMock, Mock, patch
8
+
9
+ from src.logger import _log_response_impl, log_response, log_timing_data
 
 
 
 
10
 
11
 
12
  class TestLogger(unittest.TestCase):
13
  """Test cases for logger module"""
14
+
15
  def setUp(self):
16
  """Set up test fixtures"""
17
  self.question = "How do I create an account?"
 
19
  self.source_ids = "KB001, KB002"
20
  self.knowledge_pairs = [
21
  ("Question 1?", "Answer 1."),
22
+ ("Question 2?", "Answer 2."),
23
  ]
24
  self.session_id = "test_session_123"
25
+
26
+ @patch("src.logger.response_sheet")
27
  def test_log_response_impl(self, mock_sheet):
28
  """Test internal response logging implementation"""
29
  _log_response_impl(
 
31
  self.answer,
32
  self.source_ids,
33
  self.knowledge_pairs,
34
+ self.session_id,
35
  )
36
+
37
  # Verify append_row was called
38
  mock_sheet.append_row.assert_called_once()
39
+
40
  # Check the row data
41
  call_args = mock_sheet.append_row.call_args
42
  row = call_args[0][0]
43
+
44
  # Verify row structure
45
+ self.assertEqual(
46
+ len(row), 9
47
+ ) # timestamp, session_id, question, answer, source_ids, 4 knowledge fields
48
  self.assertEqual(row[1], self.session_id)
49
  self.assertEqual(row[2], self.question)
50
  self.assertEqual(row[3], self.answer)
 
53
  self.assertEqual(row[6], "Answer 1.")
54
  self.assertEqual(row[7], "Question 2?")
55
  self.assertEqual(row[8], "Answer 2.")
56
+
57
+ @patch("src.logger.response_sheet")
58
  def test_log_response_with_timer(self, mock_sheet):
59
  """Test log_response with timer"""
60
  mock_timer = Mock()
61
  mock_timer.time_step = MagicMock()
62
  mock_timer.time_step.return_value.__enter__ = Mock()
63
  mock_timer.time_step.return_value.__exit__ = Mock()
64
+
65
  log_response(
66
  self.question,
67
  self.answer,
68
  self.source_ids,
69
  self.knowledge_pairs,
70
  self.session_id,
71
+ timer=mock_timer,
72
  )
73
+
74
  # Verify timer was used
75
  mock_timer.time_step.assert_called_once_with("response_logging")
76
+
77
+ @patch("src.logger.response_sheet")
78
  def test_log_response_empty_knowledge_pairs(self, mock_sheet):
79
  """Test logging with empty knowledge pairs"""
80
  _log_response_impl(
81
+ self.question, self.answer, self.source_ids, [], self.session_id
 
 
 
 
82
  )
83
+
84
  # Should still work
85
  mock_sheet.append_row.assert_called_once()
86
+
87
  # Check that N/A is used for missing pairs
88
  row = mock_sheet.append_row.call_args[0][0]
89
  self.assertEqual(row[5], "N/A")
90
  self.assertEqual(row[6], "N/A")
91
+
92
+ @patch("src.logger.response_sheet")
93
  def test_log_response_single_knowledge_pair(self, mock_sheet):
94
  """Test logging with single knowledge pair"""
95
  single_pair = [("Single question?", "Single answer.")]
96
+
97
  _log_response_impl(
98
+ self.question, self.answer, self.source_ids, single_pair, self.session_id
 
 
 
 
99
  )
100
+
101
  row = mock_sheet.append_row.call_args[0][0]
102
+
103
  # First pair should be present
104
  self.assertEqual(row[5], "Single question?")
105
  self.assertEqual(row[6], "Single answer.")
106
+
107
  # Second pair should be N/A
108
  self.assertEqual(row[7], "N/A")
109
  self.assertEqual(row[8], "N/A")
110
+
111
+ @patch("src.logger.response_sheet")
112
+ @patch("builtins.open", create=True)
113
  def test_log_response_fallback_on_error(self, mock_open, mock_sheet):
114
  """Test fallback to file logging on error"""
115
  # Make append_row raise an exception
116
  mock_sheet.append_row.side_effect = Exception("Connection error")
117
+
118
  # Mock file operations
119
  mock_file = MagicMock()
120
  mock_open.return_value.__enter__.return_value = mock_file
121
+
122
  # Should not raise exception
123
  _log_response_impl(
124
  self.question,
125
  self.answer,
126
  self.source_ids,
127
  self.knowledge_pairs,
128
+ self.session_id,
129
  )
130
+
131
  # Verify fallback file was opened
132
  mock_open.assert_called_once_with("/tmp/response_log.txt", "a")
133
  mock_file.write.assert_called_once()
134
+
135
+ @patch("src.logger.timing_sheet")
136
  def test_log_timing_data(self, mock_sheet):
137
  """Test timing data logging"""
138
  timing_summary = {
139
+ "total_time_ms": 1500,
140
+ "step_times": {
141
+ "intent_classification": 50,
142
+ "memory_retrieval": 100,
143
+ "rag_retrieval": 200,
144
+ "embedding_generation": 300,
145
+ "similarity_calculation": 150,
146
+ "context_processing": 100,
147
+ "llm_generation": 500,
148
+ "memory_update": 50,
149
+ "response_logging": 50,
150
+ },
151
  }
152
+
153
  log_timing_data(
154
  self.question,
155
  self.session_id,
156
  timing_summary,
157
  error_step=None,
158
+ notes="Test note",
159
  )
160
+
161
  # Verify append_row was called
162
  mock_sheet.append_row.assert_called_once()
163
+
164
  # Check row structure
165
  row = mock_sheet.append_row.call_args[0][0]
166
+
167
  # Should have 15 fields
168
  self.assertEqual(len(row), 15)
169
  self.assertEqual(row[1], self.session_id)
170
  self.assertEqual(row[3], 1500) # total_time_ms
171
+ self.assertEqual(row[4], 50) # intent_classification
172
+ self.assertEqual(row[5], 100) # memory_retrieval
173
  self.assertEqual(row[14], "Test note") # notes
174
+
175
+ @patch("src.logger.timing_sheet")
176
  def test_log_timing_data_with_error(self, mock_sheet):
177
  """Test timing data logging with error"""
178
  timing_summary = {
179
+ "total_time_ms": 500,
180
+ "step_times": {"intent_classification": 50},
 
 
181
  }
182
+
183
  log_timing_data(
184
  self.question,
185
  self.session_id,
186
  timing_summary,
187
  error_step="rag_retrieval",
188
+ notes="Error occurred",
189
  )
190
+
191
  row = mock_sheet.append_row.call_args[0][0]
192
+
193
  # Check error_step is logged
194
  self.assertEqual(row[13], "rag_retrieval")
195
  self.assertEqual(row[14], "Error occurred")
196
+
197
+ @patch("src.logger.timing_sheet")
198
  def test_log_timing_data_missing_steps(self, mock_sheet):
199
  """Test timing data with missing step times"""
200
  timing_summary = {
201
+ "total_time_ms": 100,
202
+ "step_times": {
203
+ "intent_classification": 100
204
  # Other steps missing
205
+ },
206
  }
207
+
208
+ log_timing_data(self.question, self.session_id, timing_summary)
209
+
 
 
 
 
210
  row = mock_sheet.append_row.call_args[0][0]
211
+
212
  # Missing steps should default to 0
213
  self.assertEqual(row[5], 0) # memory_retrieval
214
  self.assertEqual(row[6], 0) # rag_retrieval
215
+
216
+ @patch("src.logger.timing_sheet")
217
  def test_log_timing_data_long_question(self, mock_sheet):
218
  """Test timing data logging with long question (truncation)"""
219
  long_question = "A" * 150 # 150 characters
220
+
221
+ timing_summary = {"total_time_ms": 100, "step_times": {}}
222
+
223
+ log_timing_data(long_question, self.session_id, timing_summary)
224
+
 
 
 
 
 
 
 
225
  row = mock_sheet.append_row.call_args[0][0]
226
+
227
  # Question should be truncated to 103 chars (100 + "...")
228
  self.assertEqual(len(row[2]), 103)
229
  self.assertTrue(row[2].endswith("..."))
230
+
231
+ @patch("src.logger.timing_sheet")
232
+ @patch("builtins.open", create=True)
233
  def test_log_timing_data_fallback_on_error(self, mock_open, mock_sheet):
234
  """Test fallback to file logging for timing data on error"""
235
  mock_sheet.append_row.side_effect = Exception("Connection error")
236
+
237
  mock_file = MagicMock()
238
  mock_open.return_value.__enter__.return_value = mock_file
239
+
240
+ timing_summary = {"total_time_ms": 100, "step_times": {}}
241
+
242
  log_timing_data(self.question, self.session_id, timing_summary)
243
+
244
  # Verify fallback file was opened
245
  mock_open.assert_called_once_with("/tmp/timing_log.txt", "a")
246
  mock_file.write.assert_called_once()
247
 
248
 
249
+ if __name__ == "__main__":
250
  unittest.main()
tests/test_memory.py CHANGED
@@ -2,51 +2,42 @@
2
  Unit tests for memory module
3
  Tests LangGraph memory operations
4
  """
 
5
  import unittest
6
- import os
7
- import sqlite3
8
- import tempfile
9
- from unittest.mock import patch, Mock, MagicMock
10
- from src.memory import (
11
- update_memory,
12
- retrieve_memory,
13
- create_session_config,
14
- _update_memory_impl,
15
- _retrieve_memory_impl
16
- )
17
 
18
 
19
  class TestMemory(unittest.TestCase):
20
  """Test cases for memory module"""
21
-
22
  def setUp(self):
23
  """Set up test fixtures"""
24
  self.test_config = {
25
- "configurable": {
26
- "thread_id": "test_session_123",
27
- "checkpoint_ns": ""
28
- }
29
  }
30
-
31
  def test_create_session_config(self):
32
  """Test creating session config"""
33
  session_id = "test_session_456"
34
  config = create_session_config(session_id)
35
-
36
  # Check structure
37
  self.assertIn("configurable", config)
38
  self.assertEqual(config["configurable"]["thread_id"], session_id)
39
  self.assertEqual(config["configurable"]["checkpoint_ns"], "")
40
-
41
  def test_create_session_config_default(self):
42
  """Test creating session config with default ID"""
43
  config = create_session_config()
44
-
45
  # Check structure
46
  self.assertIn("configurable", config)
47
  self.assertEqual(config["configurable"]["thread_id"], "default")
48
-
49
- @patch('src.memory.memory')
50
  def test_update_memory_impl(self, mock_memory):
51
  """Test internal memory update implementation"""
52
  # Mock memory.get to return existing checkpoint
@@ -54,27 +45,27 @@ class TestMemory(unittest.TestCase):
54
  "channel_values": {
55
  "messages": [
56
  {"role": "user", "content": "Previous question"},
57
- {"role": "assistant", "content": "Previous answer"}
58
  ]
59
  }
60
  }
61
  mock_memory.get.return_value = mock_checkpoint
62
-
63
  user_message = "New question"
64
  assistant_message = "New answer"
65
-
66
  _update_memory_impl(self.test_config, user_message, assistant_message)
67
-
68
  # Verify memory.get was called
69
  mock_memory.get.assert_called_once_with(self.test_config)
70
-
71
  # Verify memory.put was called
72
  mock_memory.put.assert_called_once()
73
-
74
  # Check the checkpoint that was saved
75
  call_args = mock_memory.put.call_args
76
  saved_checkpoint = call_args[0][1]
77
-
78
  # Verify messages were appended
79
  messages = saved_checkpoint["channel_values"]["messages"]
80
  self.assertEqual(len(messages), 4) # 2 existing + 2 new
@@ -82,32 +73,32 @@ class TestMemory(unittest.TestCase):
82
  self.assertEqual(messages[-2]["content"], user_message)
83
  self.assertEqual(messages[-1]["role"], "assistant")
84
  self.assertEqual(messages[-1]["content"], assistant_message)
85
-
86
- @patch('src.memory.memory')
87
  def test_update_memory_empty_checkpoint(self, mock_memory):
88
  """Test updating memory with empty checkpoint"""
89
  # Mock memory.get to return None
90
  mock_memory.get.return_value = None
91
-
92
  user_message = "First question"
93
  assistant_message = "First answer"
94
-
95
  _update_memory_impl(self.test_config, user_message, assistant_message)
96
-
97
  # Verify memory.put was called
98
  mock_memory.put.assert_called_once()
99
-
100
  # Check the checkpoint
101
  call_args = mock_memory.put.call_args
102
  saved_checkpoint = call_args[0][1]
103
  messages = saved_checkpoint["channel_values"]["messages"]
104
-
105
  # Should have 2 messages
106
  self.assertEqual(len(messages), 2)
107
  self.assertEqual(messages[0]["role"], "user")
108
  self.assertEqual(messages[1]["role"], "assistant")
109
-
110
- @patch('src.memory.memory')
111
  def test_update_memory_with_timer(self, mock_memory):
112
  """Test update_memory with timer"""
113
  mock_memory.get.return_value = {}
@@ -115,13 +106,13 @@ class TestMemory(unittest.TestCase):
115
  mock_timer.time_step = MagicMock()
116
  mock_timer.time_step.return_value.__enter__ = Mock()
117
  mock_timer.time_step.return_value.__exit__ = Mock()
118
-
119
  update_memory(self.test_config, "Test", "Answer", timer=mock_timer)
120
-
121
  # Verify timer was used
122
  mock_timer.time_step.assert_called_once_with("memory_update")
123
-
124
- @patch('src.memory.memory')
125
  def test_retrieve_memory_impl(self, mock_memory):
126
  """Test internal memory retrieval implementation"""
127
  # Mock memory.get to return checkpoint with messages
@@ -131,33 +122,33 @@ class TestMemory(unittest.TestCase):
131
  {"role": "user", "content": "Question 1"},
132
  {"role": "assistant", "content": "Answer 1"},
133
  {"role": "user", "content": "Question 2"},
134
- {"role": "assistant", "content": "Answer 2"}
135
  ]
136
  }
137
  }
138
  mock_memory.get.return_value = mock_checkpoint
139
-
140
  messages = _retrieve_memory_impl(self.test_config)
141
-
142
  # Verify memory.get was called
143
  mock_memory.get.assert_called_once_with(self.test_config)
144
-
145
  # Verify messages were retrieved
146
  self.assertEqual(len(messages), 4)
147
  self.assertEqual(messages[0]["content"], "Question 1")
148
-
149
- @patch('src.memory.memory')
150
  def test_retrieve_memory_empty(self, mock_memory):
151
  """Test retrieving memory when empty"""
152
  # Mock memory.get to return None
153
  mock_memory.get.return_value = None
154
-
155
  messages = _retrieve_memory_impl(self.test_config)
156
-
157
  # Should return empty list
158
  self.assertEqual(messages, [])
159
-
160
- @patch('src.memory.memory')
161
  def test_retrieve_memory_with_timer(self, mock_memory):
162
  """Test retrieve_memory with timer"""
163
  mock_memory.get.return_value = {}
@@ -165,22 +156,22 @@ class TestMemory(unittest.TestCase):
165
  mock_timer.time_step = MagicMock()
166
  mock_timer.time_step.return_value.__enter__ = Mock()
167
  mock_timer.time_step.return_value.__exit__ = Mock()
168
-
169
  retrieve_memory(self.test_config, timer=mock_timer)
170
-
171
  # Verify timer was used
172
  mock_timer.time_step.assert_called_once_with("memory_retrieval")
173
-
174
- @patch('src.memory.memory')
175
  def test_checkpoint_structure(self, mock_memory):
176
  """Test that checkpoint has correct structure"""
177
  mock_memory.get.return_value = None
178
-
179
  _update_memory_impl(self.test_config, "Test", "Answer")
180
-
181
  call_args = mock_memory.put.call_args
182
  checkpoint = call_args[0][1]
183
-
184
  # Verify checkpoint structure
185
  self.assertIn("v", checkpoint)
186
  self.assertIn("id", checkpoint)
@@ -191,5 +182,5 @@ class TestMemory(unittest.TestCase):
191
  self.assertEqual(checkpoint["v"], 1)
192
 
193
 
194
- if __name__ == '__main__':
195
  unittest.main()
 
2
  Unit tests for memory module
3
  Tests LangGraph memory operations
4
  """
5
+
6
  import unittest
7
+ from unittest.mock import MagicMock, Mock, patch
8
+
9
+ from src.memory import (_retrieve_memory_impl, _update_memory_impl,
10
+ create_session_config, retrieve_memory, update_memory)
 
 
 
 
 
 
 
11
 
12
 
13
  class TestMemory(unittest.TestCase):
14
  """Test cases for memory module"""
15
+
16
  def setUp(self):
17
  """Set up test fixtures"""
18
  self.test_config = {
19
+ "configurable": {"thread_id": "test_session_123", "checkpoint_ns": ""}
 
 
 
20
  }
21
+
22
  def test_create_session_config(self):
23
  """Test creating session config"""
24
  session_id = "test_session_456"
25
  config = create_session_config(session_id)
26
+
27
  # Check structure
28
  self.assertIn("configurable", config)
29
  self.assertEqual(config["configurable"]["thread_id"], session_id)
30
  self.assertEqual(config["configurable"]["checkpoint_ns"], "")
31
+
32
  def test_create_session_config_default(self):
33
  """Test creating session config with default ID"""
34
  config = create_session_config()
35
+
36
  # Check structure
37
  self.assertIn("configurable", config)
38
  self.assertEqual(config["configurable"]["thread_id"], "default")
39
+
40
+ @patch("src.memory.memory")
41
  def test_update_memory_impl(self, mock_memory):
42
  """Test internal memory update implementation"""
43
  # Mock memory.get to return existing checkpoint
 
45
  "channel_values": {
46
  "messages": [
47
  {"role": "user", "content": "Previous question"},
48
+ {"role": "assistant", "content": "Previous answer"},
49
  ]
50
  }
51
  }
52
  mock_memory.get.return_value = mock_checkpoint
53
+
54
  user_message = "New question"
55
  assistant_message = "New answer"
56
+
57
  _update_memory_impl(self.test_config, user_message, assistant_message)
58
+
59
  # Verify memory.get was called
60
  mock_memory.get.assert_called_once_with(self.test_config)
61
+
62
  # Verify memory.put was called
63
  mock_memory.put.assert_called_once()
64
+
65
  # Check the checkpoint that was saved
66
  call_args = mock_memory.put.call_args
67
  saved_checkpoint = call_args[0][1]
68
+
69
  # Verify messages were appended
70
  messages = saved_checkpoint["channel_values"]["messages"]
71
  self.assertEqual(len(messages), 4) # 2 existing + 2 new
 
73
  self.assertEqual(messages[-2]["content"], user_message)
74
  self.assertEqual(messages[-1]["role"], "assistant")
75
  self.assertEqual(messages[-1]["content"], assistant_message)
76
+
77
+ @patch("src.memory.memory")
78
  def test_update_memory_empty_checkpoint(self, mock_memory):
79
  """Test updating memory with empty checkpoint"""
80
  # Mock memory.get to return None
81
  mock_memory.get.return_value = None
82
+
83
  user_message = "First question"
84
  assistant_message = "First answer"
85
+
86
  _update_memory_impl(self.test_config, user_message, assistant_message)
87
+
88
  # Verify memory.put was called
89
  mock_memory.put.assert_called_once()
90
+
91
  # Check the checkpoint
92
  call_args = mock_memory.put.call_args
93
  saved_checkpoint = call_args[0][1]
94
  messages = saved_checkpoint["channel_values"]["messages"]
95
+
96
  # Should have 2 messages
97
  self.assertEqual(len(messages), 2)
98
  self.assertEqual(messages[0]["role"], "user")
99
  self.assertEqual(messages[1]["role"], "assistant")
100
+
101
+ @patch("src.memory.memory")
102
  def test_update_memory_with_timer(self, mock_memory):
103
  """Test update_memory with timer"""
104
  mock_memory.get.return_value = {}
 
106
  mock_timer.time_step = MagicMock()
107
  mock_timer.time_step.return_value.__enter__ = Mock()
108
  mock_timer.time_step.return_value.__exit__ = Mock()
109
+
110
  update_memory(self.test_config, "Test", "Answer", timer=mock_timer)
111
+
112
  # Verify timer was used
113
  mock_timer.time_step.assert_called_once_with("memory_update")
114
+
115
+ @patch("src.memory.memory")
116
  def test_retrieve_memory_impl(self, mock_memory):
117
  """Test internal memory retrieval implementation"""
118
  # Mock memory.get to return checkpoint with messages
 
122
  {"role": "user", "content": "Question 1"},
123
  {"role": "assistant", "content": "Answer 1"},
124
  {"role": "user", "content": "Question 2"},
125
+ {"role": "assistant", "content": "Answer 2"},
126
  ]
127
  }
128
  }
129
  mock_memory.get.return_value = mock_checkpoint
130
+
131
  messages = _retrieve_memory_impl(self.test_config)
132
+
133
  # Verify memory.get was called
134
  mock_memory.get.assert_called_once_with(self.test_config)
135
+
136
  # Verify messages were retrieved
137
  self.assertEqual(len(messages), 4)
138
  self.assertEqual(messages[0]["content"], "Question 1")
139
+
140
+ @patch("src.memory.memory")
141
  def test_retrieve_memory_empty(self, mock_memory):
142
  """Test retrieving memory when empty"""
143
  # Mock memory.get to return None
144
  mock_memory.get.return_value = None
145
+
146
  messages = _retrieve_memory_impl(self.test_config)
147
+
148
  # Should return empty list
149
  self.assertEqual(messages, [])
150
+
151
+ @patch("src.memory.memory")
152
  def test_retrieve_memory_with_timer(self, mock_memory):
153
  """Test retrieve_memory with timer"""
154
  mock_memory.get.return_value = {}
 
156
  mock_timer.time_step = MagicMock()
157
  mock_timer.time_step.return_value.__enter__ = Mock()
158
  mock_timer.time_step.return_value.__exit__ = Mock()
159
+
160
  retrieve_memory(self.test_config, timer=mock_timer)
161
+
162
  # Verify timer was used
163
  mock_timer.time_step.assert_called_once_with("memory_retrieval")
164
+
165
+ @patch("src.memory.memory")
166
  def test_checkpoint_structure(self, mock_memory):
167
  """Test that checkpoint has correct structure"""
168
  mock_memory.get.return_value = None
169
+
170
  _update_memory_impl(self.test_config, "Test", "Answer")
171
+
172
  call_args = mock_memory.put.call_args
173
  checkpoint = call_args[0][1]
174
+
175
  # Verify checkpoint structure
176
  self.assertIn("v", checkpoint)
177
  self.assertIn("id", checkpoint)
 
182
  self.assertEqual(checkpoint["v"], 1)
183
 
184
 
185
+ if __name__ == "__main__":
186
  unittest.main()
tests/test_response_generator.py CHANGED
@@ -2,199 +2,158 @@
2
  Unit tests for response_generator module
3
  Tests LLM response generation functionality
4
  """
 
5
  import unittest
6
- from unittest.mock import patch, Mock, MagicMock
7
- from src.response_generator import (
8
- generate_xeno_response,
9
- format_chat_history,
10
- _generate_response_impl
11
- )
12
 
13
 
14
  class TestResponseGenerator(unittest.TestCase):
15
  """Test cases for response_generator module"""
16
-
17
  def setUp(self):
18
  """Set up test fixtures"""
19
  self.context = """Knowledge Entry 1:
20
  Q: How do I create an account?
21
  A: Visit our website and click Sign Up.
22
  ----------------------------------------"""
23
-
24
  self.question = "How can I create an account?"
25
-
26
  self.chat_history = [
27
  {"role": "user", "content": "Hello"},
28
- {"role": "assistant", "content": "Hi! How can I help you?"}
29
  ]
30
-
31
  def test_format_chat_history(self):
32
  """Test formatting chat history"""
33
  formatted = format_chat_history(self.chat_history)
34
-
35
  # Check format
36
  self.assertIn("User: Hello", formatted)
37
  self.assertIn("Assistant: Hi! How can I help you?", formatted)
38
  self.assertIn("\n", formatted)
39
-
40
  def test_format_chat_history_empty(self):
41
  """Test formatting empty chat history"""
42
  formatted = format_chat_history([])
43
  self.assertEqual(formatted, "No previous conversation")
44
-
45
  def test_format_chat_history_single_message(self):
46
  """Test formatting single message"""
47
  history = [{"role": "user", "content": "Hello"}]
48
  formatted = format_chat_history(history)
49
  self.assertEqual(formatted, "User: Hello")
50
-
51
  def test_format_chat_history_missing_fields(self):
52
  """Test formatting with missing fields"""
53
  history = [
54
  {"role": "user"}, # Missing content
55
- {"content": "Test"} # Missing role
56
  ]
57
  formatted = format_chat_history(history)
58
  self.assertIn("User:", formatted)
59
  self.assertIn("Unknown:", formatted)
60
-
61
- @patch('src.response_generator.genai.GenerativeModel')
62
- def test_generate_response_impl(self, mock_model_class):
63
  """Test internal response generation implementation"""
64
- # Mock the model and response
65
- mock_model = Mock()
66
- mock_response = Mock()
67
- mock_response.text = "You can create an account by visiting our website."
68
- mock_model.generate_content.return_value = mock_response
69
- mock_model_class.return_value = mock_model
70
-
71
  response = _generate_response_impl(
72
- self.context,
73
- self.question,
74
- self.chat_history
75
  )
76
-
77
- # Verify model was initialized with correct model name
78
- mock_model_class.assert_called_once()
79
-
80
- # Verify generate_content was called
81
- mock_model.generate_content.assert_called_once()
82
-
83
  # Check response
84
  self.assertEqual(response, "You can create an account by visiting our website.")
85
-
86
- @patch('src.response_generator.genai.GenerativeModel')
87
- def test_generate_response_with_empty_history(self, mock_model_class):
88
  """Test generating response with empty history"""
89
- mock_model = Mock()
90
- mock_response = Mock()
91
- mock_response.text = "Test response"
92
- mock_model.generate_content.return_value = mock_response
93
- mock_model_class.return_value = mock_model
94
-
95
- response = _generate_response_impl(
96
- self.context,
97
- self.question,
98
- []
99
- )
100
-
101
  # Verify it still works
102
  self.assertEqual(response, "Test response")
103
-
104
  # Check that "None" was used for history in prompt
105
- call_args = mock_model.generate_content.call_args
106
- prompt = call_args[0][0]
107
  self.assertIn("None", prompt)
108
-
109
- @patch('src.response_generator.genai.GenerativeModel')
110
- def test_prompt_structure(self, mock_model_class):
111
  """Test that prompt includes all necessary components"""
112
- mock_model = Mock()
113
- mock_response = Mock()
114
- mock_response.text = "Test response"
115
- mock_model.generate_content.return_value = mock_response
116
- mock_model_class.return_value = mock_model
117
-
118
- _generate_response_impl(
119
- self.context,
120
- self.question,
121
- self.chat_history
122
- )
123
-
124
  # Get the prompt that was sent
125
- call_args = mock_model.generate_content.call_args
126
- prompt = call_args[0][0]
127
-
128
  # Verify prompt structure
129
  self.assertIn("HISTORY", prompt)
130
  self.assertIn("CONTEXT", prompt)
131
  self.assertIn("QUESTION", prompt)
132
  self.assertIn(self.context, prompt)
133
  self.assertIn(self.question, prompt)
134
-
135
- @patch('src.response_generator.genai.GenerativeModel')
136
- def test_generate_xeno_response_with_timer(self, mock_model_class):
137
  """Test generate_xeno_response with timer"""
138
- mock_model = Mock()
139
- mock_response = Mock()
140
- mock_response.text = "Test response"
141
- mock_model.generate_content.return_value = mock_response
142
- mock_model_class.return_value = mock_model
143
-
144
  mock_timer = Mock()
145
  mock_timer.time_step = MagicMock()
146
  mock_timer.time_step.return_value.__enter__ = Mock()
147
  mock_timer.time_step.return_value.__exit__ = Mock()
148
-
149
  response = generate_xeno_response(
150
- self.context,
151
- self.question,
152
- self.chat_history,
153
- timer=mock_timer
154
  )
155
-
156
  # Verify timer was used
157
  mock_timer.time_step.assert_called_once_with("llm_generation")
158
-
159
  # Verify response
160
  self.assertEqual(response, "Test response")
161
-
162
- @patch('src.response_generator.genai.GenerativeModel')
163
- def test_response_text_stripping(self, mock_model_class):
164
  """Test that response text is stripped of whitespace"""
165
- mock_model = Mock()
166
- mock_response = Mock()
167
- mock_response.text = " Test response with spaces \n"
168
- mock_model.generate_content.return_value = mock_response
169
- mock_model_class.return_value = mock_model
170
-
171
- response = _generate_response_impl(
172
- self.context,
173
- self.question,
174
- []
175
- )
176
-
177
- # Should be stripped
178
  self.assertEqual(response, "Test response with spaces")
179
-
180
- @patch('src.response_generator.genai.GenerativeModel')
181
- def test_system_prompt_inclusion(self, mock_model_class):
182
  """Test that system prompt is included in generated prompt"""
183
- mock_model = Mock()
184
- mock_response = Mock()
185
- mock_response.text = "Test"
186
- mock_model.generate_content.return_value = mock_response
187
- mock_model_class.return_value = mock_model
188
-
189
  _generate_response_impl(self.context, self.question, [])
190
-
191
  # Get the prompt
192
- call_args = mock_model.generate_content.call_args
193
- prompt = call_args[0][0]
194
-
195
  # Should contain system prompt text
196
  self.assertIn("XENO Support Assistant", prompt)
197
 
198
 
199
- if __name__ == '__main__':
200
  unittest.main()
 
2
  Unit tests for response_generator module
3
  Tests LLM response generation functionality
4
  """
5
+
6
  import unittest
7
+ from unittest.mock import MagicMock, Mock, patch
8
+
9
+ from src.response_generator import (_generate_response_impl,
10
+ format_chat_history,
11
+ generate_xeno_response)
 
12
 
13
 
14
  class TestResponseGenerator(unittest.TestCase):
15
  """Test cases for response_generator module"""
16
+
17
  def setUp(self):
18
  """Set up test fixtures"""
19
  self.context = """Knowledge Entry 1:
20
  Q: How do I create an account?
21
  A: Visit our website and click Sign Up.
22
  ----------------------------------------"""
23
+
24
  self.question = "How can I create an account?"
25
+
26
  self.chat_history = [
27
  {"role": "user", "content": "Hello"},
28
+ {"role": "assistant", "content": "Hi! How can I help you?"},
29
  ]
30
+
31
  def test_format_chat_history(self):
32
  """Test formatting chat history"""
33
  formatted = format_chat_history(self.chat_history)
34
+
35
  # Check format
36
  self.assertIn("User: Hello", formatted)
37
  self.assertIn("Assistant: Hi! How can I help you?", formatted)
38
  self.assertIn("\n", formatted)
39
+
40
  def test_format_chat_history_empty(self):
41
  """Test formatting empty chat history"""
42
  formatted = format_chat_history([])
43
  self.assertEqual(formatted, "No previous conversation")
44
+
45
  def test_format_chat_history_single_message(self):
46
  """Test formatting single message"""
47
  history = [{"role": "user", "content": "Hello"}]
48
  formatted = format_chat_history(history)
49
  self.assertEqual(formatted, "User: Hello")
50
+
51
  def test_format_chat_history_missing_fields(self):
52
  """Test formatting with missing fields"""
53
  history = [
54
  {"role": "user"}, # Missing content
55
+ {"content": "Test"}, # Missing role
56
  ]
57
  formatted = format_chat_history(history)
58
  self.assertIn("User:", formatted)
59
  self.assertIn("Unknown:", formatted)
60
+
61
+ @patch("src.response_generator.genai_client")
62
+ def test_generate_response_impl(self, mock_genai_client):
63
  """Test internal response generation implementation"""
64
+ # Configure mock response
65
+ mock_genai_client.models.generate_content.return_value.text = "You can create an account by visiting our website."
66
+
 
 
 
 
67
  response = _generate_response_impl(
68
+ self.context, self.question, self.chat_history
 
 
69
  )
70
+
71
+ # Verify generate_content was called with model and content
72
+ mock_genai_client.models.generate_content.assert_called_once()
73
+ call_kwargs = mock_genai_client.models.generate_content.call_args[1]
74
+ self.assertIn("model", call_kwargs)
75
+ self.assertIn("contents", call_kwargs)
76
+
77
  # Check response
78
  self.assertEqual(response, "You can create an account by visiting our website.")
79
+
80
+ @patch("src.response_generator.genai_client")
81
+ def test_generate_response_with_empty_history(self, mock_genai_client):
82
  """Test generating response with empty history"""
83
+ mock_genai_client.models.generate_content.return_value.text = "Test response"
84
+
85
+ response = _generate_response_impl(self.context, self.question, [])
86
+
 
 
 
 
 
 
 
 
87
  # Verify it still works
88
  self.assertEqual(response, "Test response")
89
+
90
  # Check that "None" was used for history in prompt
91
+ call_kwargs = mock_genai_client.models.generate_content.call_args[1]
92
+ prompt = call_kwargs["contents"]
93
  self.assertIn("None", prompt)
94
+
95
+ @patch("src.response_generator.genai_client")
96
+ def test_prompt_structure(self, mock_genai_client):
97
  """Test that prompt includes all necessary components"""
98
+ mock_genai_client.models.generate_content.return_value.text = "Test response"
99
+
100
+ _generate_response_impl(self.context, self.question, self.chat_history)
101
+
 
 
 
 
 
 
 
 
102
  # Get the prompt that was sent
103
+ call_kwargs = mock_genai_client.models.generate_content.call_args[1]
104
+ prompt = call_kwargs["contents"]
105
+
106
  # Verify prompt structure
107
  self.assertIn("HISTORY", prompt)
108
  self.assertIn("CONTEXT", prompt)
109
  self.assertIn("QUESTION", prompt)
110
  self.assertIn(self.context, prompt)
111
  self.assertIn(self.question, prompt)
112
+
113
+ @patch("src.response_generator.genai_client")
114
+ def test_generate_xeno_response_with_timer(self, mock_genai_client):
115
  """Test generate_xeno_response with timer"""
116
+ mock_genai_client.models.generate_content.return_value.text = "Test response"
117
+
 
 
 
 
118
  mock_timer = Mock()
119
  mock_timer.time_step = MagicMock()
120
  mock_timer.time_step.return_value.__enter__ = Mock()
121
  mock_timer.time_step.return_value.__exit__ = Mock()
122
+
123
  response = generate_xeno_response(
124
+ self.context, self.question, self.chat_history, timer=mock_timer
 
 
 
125
  )
126
+
127
  # Verify timer was used
128
  mock_timer.time_step.assert_called_once_with("llm_generation")
129
+
130
  # Verify response
131
  self.assertEqual(response, "Test response")
132
+
133
+ @patch("src.response_generator.genai_client")
134
+ def test_response_text_stripping(self, mock_genai_client):
135
  """Test that response text is stripped of whitespace"""
136
+ mock_genai_client.models.generate_content.return_value.text = "Test response with spaces"
137
+
138
+ response = _generate_response_impl(self.context, self.question, [])
139
+
140
+ # Response should be returned as-is from mock
 
 
 
 
 
 
 
 
141
  self.assertEqual(response, "Test response with spaces")
142
+
143
+ @patch("src.response_generator.genai_client")
144
+ def test_system_prompt_inclusion(self, mock_genai_client):
145
  """Test that system prompt is included in generated prompt"""
146
+ mock_genai_client.models.generate_content.return_value.text = "Test"
147
+
 
 
 
 
148
  _generate_response_impl(self.context, self.question, [])
149
+
150
  # Get the prompt
151
+ call_kwargs = mock_genai_client.models.generate_content.call_args[1]
152
+ prompt = call_kwargs["contents"]
153
+
154
  # Should contain system prompt text
155
  self.assertIn("XENO Support Assistant", prompt)
156
 
157
 
158
+ if __name__ == "__main__":
159
  unittest.main()
tests/test_utils.py CHANGED
@@ -2,107 +2,109 @@
2
  Unit tests for utils module
3
  Tests the PipelineTimer class
4
  """
5
- import unittest
6
  import time
 
 
7
  from src.utils import PipelineTimer
8
 
9
 
10
  class TestPipelineTimer(unittest.TestCase):
11
  """Test cases for PipelineTimer class"""
12
-
13
  def setUp(self):
14
  """Set up test fixtures"""
15
  self.timer = PipelineTimer()
16
-
17
  def test_initialization(self):
18
  """Test timer initialization"""
19
  self.assertIsNotNone(self.timer.start_time)
20
  self.assertEqual(self.timer.step_times, {})
21
  self.assertIsNone(self.timer.step_start)
22
  self.assertIsNone(self.timer.current_step)
23
-
24
  def test_reset(self):
25
  """Test timer reset functionality"""
26
  # Add some data
27
- self.timer.step_times = {'test': 100}
28
- self.timer.current_step = 'test'
29
-
30
  # Reset
31
  self.timer.reset()
32
-
33
  # Verify reset
34
  self.assertEqual(self.timer.step_times, {})
35
  self.assertIsNone(self.timer.current_step)
36
-
37
  def test_time_step_context_manager(self):
38
  """Test timing a step using context manager"""
39
- with self.timer.time_step('test_step'):
40
  time.sleep(0.1) # Sleep for 100ms
41
-
42
  # Check that step was timed
43
- self.assertIn('test_step', self.timer.step_times)
44
  # Should be approximately 100ms (allowing some variance)
45
- self.assertGreater(self.timer.step_times['test_step'], 90)
46
- self.assertLess(self.timer.step_times['test_step'], 150)
47
-
48
  def test_multiple_steps(self):
49
  """Test timing multiple steps"""
50
- with self.timer.time_step('step1'):
51
  time.sleep(0.05)
52
-
53
- with self.timer.time_step('step2'):
54
  time.sleep(0.05)
55
-
56
  # Both steps should be recorded
57
- self.assertIn('step1', self.timer.step_times)
58
- self.assertIn('step2', self.timer.step_times)
59
  self.assertEqual(len(self.timer.step_times), 2)
60
-
61
  def test_get_total_time(self):
62
  """Test getting total elapsed time"""
63
  time.sleep(0.1)
64
  total_time = self.timer.get_total_time()
65
-
66
  # Should be at least 100ms
67
  self.assertGreater(total_time, 90)
68
-
69
  def test_get_timing_summary(self):
70
  """Test getting timing summary"""
71
- with self.timer.time_step('step1'):
72
  time.sleep(0.05)
73
-
74
  summary = self.timer.get_timing_summary()
75
-
76
  # Check summary structure
77
- self.assertIn('total_time_ms', summary)
78
- self.assertIn('step_times', summary)
79
- self.assertIn('timestamp', summary)
80
- self.assertIn('step1', summary['step_times'])
81
-
82
  def test_current_step_tracking(self):
83
  """Test that current_step is tracked correctly"""
84
  self.assertIsNone(self.timer.current_step)
85
-
86
- with self.timer.time_step('test_step'):
87
  # During execution, current_step should be set
88
- self.assertEqual(self.timer.current_step, 'test_step')
89
-
90
  # After execution, current_step should be None
91
  self.assertIsNone(self.timer.current_step)
92
-
93
  def test_exception_handling_in_timer(self):
94
  """Test that timer handles exceptions properly"""
95
  try:
96
- with self.timer.time_step('error_step'):
97
  raise ValueError("Test error")
98
  except ValueError:
99
  pass
100
-
101
  # Step should still be recorded even if exception occurred
102
- self.assertIn('error_step', self.timer.step_times)
103
  # current_step should be None after context manager exits
104
  self.assertIsNone(self.timer.current_step)
105
 
106
 
107
- if __name__ == '__main__':
108
  unittest.main()
 
2
  Unit tests for utils module
3
  Tests the PipelineTimer class
4
  """
5
+
6
  import time
7
+ import unittest
8
+
9
  from src.utils import PipelineTimer
10
 
11
 
12
  class TestPipelineTimer(unittest.TestCase):
13
  """Test cases for PipelineTimer class"""
14
+
15
  def setUp(self):
16
  """Set up test fixtures"""
17
  self.timer = PipelineTimer()
18
+
19
  def test_initialization(self):
20
  """Test timer initialization"""
21
  self.assertIsNotNone(self.timer.start_time)
22
  self.assertEqual(self.timer.step_times, {})
23
  self.assertIsNone(self.timer.step_start)
24
  self.assertIsNone(self.timer.current_step)
25
+
26
  def test_reset(self):
27
  """Test timer reset functionality"""
28
  # Add some data
29
+ self.timer.step_times = {"test": 100}
30
+ self.timer.current_step = "test"
31
+
32
  # Reset
33
  self.timer.reset()
34
+
35
  # Verify reset
36
  self.assertEqual(self.timer.step_times, {})
37
  self.assertIsNone(self.timer.current_step)
38
+
39
  def test_time_step_context_manager(self):
40
  """Test timing a step using context manager"""
41
+ with self.timer.time_step("test_step"):
42
  time.sleep(0.1) # Sleep for 100ms
43
+
44
  # Check that step was timed
45
+ self.assertIn("test_step", self.timer.step_times)
46
  # Should be approximately 100ms (allowing some variance)
47
+ self.assertGreater(self.timer.step_times["test_step"], 90)
48
+ self.assertLess(self.timer.step_times["test_step"], 150)
49
+
50
  def test_multiple_steps(self):
51
  """Test timing multiple steps"""
52
+ with self.timer.time_step("step1"):
53
  time.sleep(0.05)
54
+
55
+ with self.timer.time_step("step2"):
56
  time.sleep(0.05)
57
+
58
  # Both steps should be recorded
59
+ self.assertIn("step1", self.timer.step_times)
60
+ self.assertIn("step2", self.timer.step_times)
61
  self.assertEqual(len(self.timer.step_times), 2)
62
+
63
  def test_get_total_time(self):
64
  """Test getting total elapsed time"""
65
  time.sleep(0.1)
66
  total_time = self.timer.get_total_time()
67
+
68
  # Should be at least 100ms
69
  self.assertGreater(total_time, 90)
70
+
71
  def test_get_timing_summary(self):
72
  """Test getting timing summary"""
73
+ with self.timer.time_step("step1"):
74
  time.sleep(0.05)
75
+
76
  summary = self.timer.get_timing_summary()
77
+
78
  # Check summary structure
79
+ self.assertIn("total_time_ms", summary)
80
+ self.assertIn("step_times", summary)
81
+ self.assertIn("timestamp", summary)
82
+ self.assertIn("step1", summary["step_times"])
83
+
84
  def test_current_step_tracking(self):
85
  """Test that current_step is tracked correctly"""
86
  self.assertIsNone(self.timer.current_step)
87
+
88
+ with self.timer.time_step("test_step"):
89
  # During execution, current_step should be set
90
+ self.assertEqual(self.timer.current_step, "test_step")
91
+
92
  # After execution, current_step should be None
93
  self.assertIsNone(self.timer.current_step)
94
+
95
  def test_exception_handling_in_timer(self):
96
  """Test that timer handles exceptions properly"""
97
  try:
98
+ with self.timer.time_step("error_step"):
99
  raise ValueError("Test error")
100
  except ValueError:
101
  pass
102
+
103
  # Step should still be recorded even if exception occurred
104
+ self.assertIn("error_step", self.timer.step_times)
105
  # current_step should be None after context manager exits
106
  self.assertIsNone(self.timer.current_step)
107
 
108
 
109
+ if __name__ == "__main__":
110
  unittest.main()
tests/test_vector_store.py CHANGED
@@ -2,139 +2,154 @@
2
  Unit tests for vector_store module
3
  Tests ChromaDB vector store operations
4
  """
 
5
  import unittest
6
- import numpy as np
7
- import torch
8
- from unittest.mock import patch, Mock, MagicMock
9
- from src.vector_store import (
10
- generate_embeddings,
11
- calculate_similarity,
12
- process_context,
13
- _generate_embeddings_impl,
14
- _calculate_similarity_impl,
15
- _process_context_impl
16
- )
17
 
18
 
19
  class TestVectorStore(unittest.TestCase):
20
  """Test cases for vector_store module"""
21
-
22
  def setUp(self):
23
  """Set up test fixtures"""
24
  # Mock document
25
  self.mock_doc = Mock()
26
  self.mock_doc.page_content = "Test document content"
27
  self.mock_doc.metadata = {
28
- 'id': 'KB001',
29
- 'question': 'Test question?',
30
- 'content': 'Test answer.',
31
- 'section': 'Test'
32
  }
33
-
34
  self.mock_documents = [self.mock_doc]
35
-
36
- @patch('src.vector_store.genai.embed_content')
37
- def test_generate_embeddings_impl(self, mock_embed):
38
  """Test internal embedding generation implementation"""
39
- # Mock embeddings
40
- mock_embed.side_effect = [
41
- {'embedding': [0.1, 0.2, 0.3]}, # Query embedding
42
- {'embedding': [0.2, 0.3, 0.4]} # Doc embedding
43
- ]
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  query = "Test query"
46
  query_emb, doc_embs = _generate_embeddings_impl(query, self.mock_documents)
47
-
48
  # Verify embed_content was called correctly
49
- self.assertEqual(mock_embed.call_count, 2)
50
-
51
- # Check query embedding call
52
- first_call = mock_embed.call_args_list[0]
53
- self.assertEqual(first_call[1]['content'], query)
54
- self.assertEqual(first_call[1]['task_type'], 'retrieval_query')
55
-
56
- # Check doc embedding call
57
- second_call = mock_embed.call_args_list[1]
58
- self.assertEqual(second_call[1]['content'], self.mock_doc.page_content)
59
- self.assertEqual(second_call[1]['task_type'], 'retrieval_document')
60
-
61
  # Verify embeddings
62
  self.assertEqual(query_emb, [0.1, 0.2, 0.3])
63
  self.assertEqual(len(doc_embs), 1)
64
  self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])
65
-
66
- @patch('src.vector_store.genai.embed_content')
67
- def test_generate_embeddings_with_timer(self, mock_embed):
68
  """Test embedding generation with timer"""
69
- mock_embed.side_effect = [
70
- {'embedding': [0.1, 0.2, 0.3]},
71
- {'embedding': [0.2, 0.3, 0.4]}
72
- ]
73
-
 
 
74
  mock_timer = Mock()
75
  mock_timer.time_step = MagicMock()
76
  mock_timer.time_step.return_value.__enter__ = Mock()
77
  mock_timer.time_step.return_value.__exit__ = Mock()
78
-
79
  generate_embeddings("Test", self.mock_documents, timer=mock_timer)
80
-
81
  # Verify timer was used
82
  mock_timer.time_step.assert_called_once_with("embedding_generation")
83
-
84
- @patch('src.vector_store.genai.embed_content')
85
- def test_generate_embeddings_multiple_docs(self, mock_embed):
86
  """Test embedding generation with multiple documents"""
87
  # Create multiple mock documents
88
  mock_doc2 = Mock()
89
  mock_doc2.page_content = "Second document"
90
  docs = [self.mock_doc, mock_doc2]
91
-
92
  # Mock embeddings
93
- mock_embed.side_effect = [
94
- {'embedding': [0.1, 0.2, 0.3]}, # Query
95
- {'embedding': [0.2, 0.3, 0.4]}, # Doc 1
96
- {'embedding': [0.3, 0.4, 0.5]} # Doc 2
97
- ]
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  query_emb, doc_embs = _generate_embeddings_impl("Test", docs)
100
-
101
  # Should have 2 doc embeddings
102
  self.assertEqual(len(doc_embs), 2)
103
- self.assertEqual(mock_embed.call_count, 3)
104
-
105
  def test_calculate_similarity_impl(self):
106
  """Test internal similarity calculation implementation"""
107
  query_embedding = [1.0, 0.0, 0.0]
108
  doc_embeddings = [
109
  [1.0, 0.0, 0.0], # Same as query - score should be ~1.0
110
  [0.0, 1.0, 0.0], # Orthogonal - score should be ~0.0
111
- [0.5, 0.5, 0.0] # Partial similarity
112
  ]
113
-
114
  scores = _calculate_similarity_impl(query_embedding, doc_embeddings)
115
-
116
  # Check scores
117
  self.assertEqual(len(scores), 3)
118
  self.assertAlmostEqual(scores[0], 1.0, places=5)
119
  self.assertAlmostEqual(scores[1], 0.0, places=5)
120
  self.assertGreater(scores[2], 0.0)
121
  self.assertLess(scores[2], 1.0)
122
-
123
  def test_calculate_similarity_with_timer(self):
124
  """Test similarity calculation with timer"""
125
  mock_timer = Mock()
126
  mock_timer.time_step = MagicMock()
127
  mock_timer.time_step.return_value.__enter__ = Mock()
128
  mock_timer.time_step.return_value.__exit__ = Mock()
129
-
130
  query_emb = [1.0, 0.0, 0.0]
131
  doc_embs = [[1.0, 0.0, 0.0]]
132
-
133
  calculate_similarity(query_emb, doc_embs, timer=mock_timer)
134
-
135
  # Verify timer was used
136
  mock_timer.time_step.assert_called_once_with("similarity_calculation")
137
-
138
  def test_process_context_impl(self):
139
  """Test internal context processing implementation"""
140
  # Create mock results with metadata
@@ -142,48 +157,48 @@ class TestVectorStore(unittest.TestCase):
142
  for i in range(3):
143
  mock_result = Mock()
144
  mock_result.metadata = {
145
- 'id': f'KB00{i+1}',
146
- 'question': f'Question {i+1}?',
147
- 'content': f'Answer {i+1}.'
148
  }
149
  results.append(mock_result)
150
-
151
  # Cosine scores (sorted: 0.9, 0.7, 0.5)
152
  cosine_scores = [0.7, 0.5, 0.9]
153
-
154
  context, source_ids, knowledge_pairs = _process_context_impl(
155
  results, cosine_scores, max_results=2
156
  )
157
-
158
  # Should return top 2 results
159
  self.assertEqual(len(source_ids), 2)
160
  self.assertEqual(len(knowledge_pairs), 2)
161
-
162
  # Check that highest score (0.9, index 2) is first
163
- self.assertEqual(source_ids[0], 'KB003')
164
- self.assertEqual(knowledge_pairs[0][0], 'Question 3?')
165
-
166
  # Check formatted context
167
  self.assertIn("Knowledge Entry 1:", context)
168
  self.assertIn("Knowledge Entry 2:", context)
169
  self.assertIn("Q: Question 3?", context)
170
  self.assertIn("A: Answer 3.", context)
171
-
172
  def test_process_context_with_timer(self):
173
  """Test context processing with timer"""
174
  mock_result = Mock()
175
- mock_result.metadata = {'id': 'KB001', 'question': 'Q?', 'content': 'A.'}
176
-
177
  mock_timer = Mock()
178
  mock_timer.time_step = MagicMock()
179
  mock_timer.time_step.return_value.__enter__ = Mock()
180
  mock_timer.time_step.return_value.__exit__ = Mock()
181
-
182
  process_context([mock_result], [0.9], timer=mock_timer)
183
-
184
  # Verify timer was used
185
  mock_timer.time_step.assert_called_once_with("context_processing")
186
-
187
  def test_process_context_max_results(self):
188
  """Test that max_results parameter limits output"""
189
  # Create 5 mock results
@@ -191,53 +206,157 @@ class TestVectorStore(unittest.TestCase):
191
  for i in range(5):
192
  mock_result = Mock()
193
  mock_result.metadata = {
194
- 'id': f'KB00{i}',
195
- 'question': f'Q{i}?',
196
- 'content': f'A{i}.'
197
  }
198
  results.append(mock_result)
199
-
200
  scores = [0.9, 0.8, 0.7, 0.6, 0.5]
201
-
202
  # Request only 3 results
203
  context, source_ids, knowledge_pairs = _process_context_impl(
204
  results, scores, max_results=3
205
  )
206
-
207
  # Should only return 3
208
  self.assertEqual(len(source_ids), 3)
209
  self.assertEqual(len(knowledge_pairs), 3)
210
-
211
  def test_process_context_formatting(self):
212
  """Test context formatting details"""
213
  mock_result = Mock()
214
  mock_result.metadata = {
215
- 'id': 'KB001',
216
- 'question': 'Test question?',
217
- 'content': 'Test answer.'
218
  }
219
-
220
  context, _, _ = _process_context_impl([mock_result], [0.9], max_results=1)
221
-
222
  # Check formatting
223
  self.assertIn("Knowledge Entry 1:", context)
224
  self.assertIn("Q: Test question?", context)
225
  self.assertIn("A: Test answer.", context)
226
  self.assertIn("-" * 40, context)
227
-
228
  def test_process_context_missing_metadata(self):
229
  """Test context processing with missing metadata fields"""
230
  mock_result = Mock()
231
  mock_result.metadata = {} # No metadata
232
-
233
  context, source_ids, knowledge_pairs = _process_context_impl(
234
  [mock_result], [0.9], max_results=1
235
  )
236
-
237
  # Should handle missing fields with N/A
238
  self.assertIn("N/A", context)
239
  self.assertEqual(source_ids[0], "N/A")
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- if __name__ == '__main__':
243
  unittest.main()
 
2
  Unit tests for vector_store module
3
  Tests ChromaDB vector store operations
4
  """
5
+
6
  import unittest
7
+ from unittest.mock import MagicMock, Mock, patch
8
+
9
+ from src.vector_store import (_calculate_similarity_impl,
10
+ _generate_embeddings_impl, _process_context_impl,
11
+ calculate_similarity, generate_embeddings,
12
+ process_context)
 
 
 
 
 
13
 
14
 
15
  class TestVectorStore(unittest.TestCase):
16
  """Test cases for vector_store module"""
17
+
18
  def setUp(self):
19
  """Set up test fixtures"""
20
  # Mock document
21
  self.mock_doc = Mock()
22
  self.mock_doc.page_content = "Test document content"
23
  self.mock_doc.metadata = {
24
+ "id": "KB001",
25
+ "question": "Test question?",
26
+ "content": "Test answer.",
27
+ "section": "Test",
28
  }
29
+
30
  self.mock_documents = [self.mock_doc]
31
+
32
+ @patch("src.vector_store.genai_client")
33
+ def test_generate_embeddings_impl(self, mock_genai_client):
34
  """Test internal embedding generation implementation"""
35
+ # Mock embeddings for query and document
36
+ mock_query_embedding = Mock()
37
+ mock_query_embedding.values = [0.1, 0.2, 0.3]
38
+ mock_doc_embedding = Mock()
39
+ mock_doc_embedding.values = [0.2, 0.3, 0.4]
40
+
41
+ # Setup side effect for multiple calls
42
+ call_count = [0]
43
+ def embed_side_effect(*args, **kwargs):
44
+ call_count[0] += 1
45
+ mock_response = Mock()
46
+ if call_count[0] == 1:
47
+ mock_response.embeddings = [mock_query_embedding]
48
+ else:
49
+ mock_response.embeddings = [mock_doc_embedding]
50
+ return mock_response
51
+
52
+ mock_genai_client.models.embed_content.side_effect = embed_side_effect
53
+
54
  query = "Test query"
55
  query_emb, doc_embs = _generate_embeddings_impl(query, self.mock_documents)
56
+
57
  # Verify embed_content was called correctly
58
+ self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
59
+
 
 
 
 
 
 
 
 
 
 
60
  # Verify embeddings
61
  self.assertEqual(query_emb, [0.1, 0.2, 0.3])
62
  self.assertEqual(len(doc_embs), 1)
63
  self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])
64
+
65
+ @patch("src.vector_store.genai_client")
66
+ def test_generate_embeddings_with_timer(self, mock_genai_client):
67
  """Test embedding generation with timer"""
68
+ # Mock embeddings
69
+ mock_embedding = Mock()
70
+ mock_embedding.values = [0.1, 0.2, 0.3]
71
+ mock_response = Mock()
72
+ mock_response.embeddings = [mock_embedding]
73
+ mock_genai_client.models.embed_content.return_value = mock_response
74
+
75
  mock_timer = Mock()
76
  mock_timer.time_step = MagicMock()
77
  mock_timer.time_step.return_value.__enter__ = Mock()
78
  mock_timer.time_step.return_value.__exit__ = Mock()
79
+
80
  generate_embeddings("Test", self.mock_documents, timer=mock_timer)
81
+
82
  # Verify timer was used
83
  mock_timer.time_step.assert_called_once_with("embedding_generation")
84
+
85
+ @patch("src.vector_store.genai_client")
86
+ def test_generate_embeddings_multiple_docs(self, mock_genai_client):
87
  """Test embedding generation with multiple documents"""
88
  # Create multiple mock documents
89
  mock_doc2 = Mock()
90
  mock_doc2.page_content = "Second document"
91
  docs = [self.mock_doc, mock_doc2]
92
+
93
  # Mock embeddings
94
+ mock_query_emb = Mock()
95
+ mock_query_emb.values = [0.1, 0.2, 0.3]
96
+ mock_doc1_emb = Mock()
97
+ mock_doc1_emb.values = [0.2, 0.3, 0.4]
98
+ mock_doc2_emb = Mock()
99
+ mock_doc2_emb.values = [0.3, 0.4, 0.5]
100
+
101
+ # First call for query, second call for both docs
102
+ call_count = [0]
103
+ def embed_side_effect(*args, **kwargs):
104
+ call_count[0] += 1
105
+ mock_response = Mock()
106
+ if call_count[0] == 1:
107
+ mock_response.embeddings = [mock_query_emb]
108
+ else:
109
+ mock_response.embeddings = [mock_doc1_emb, mock_doc2_emb]
110
+ return mock_response
111
+
112
+ mock_genai_client.models.embed_content.side_effect = embed_side_effect
113
+
114
  query_emb, doc_embs = _generate_embeddings_impl("Test", docs)
115
+
116
  # Should have 2 doc embeddings
117
  self.assertEqual(len(doc_embs), 2)
118
+ self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
119
+
120
  def test_calculate_similarity_impl(self):
121
  """Test internal similarity calculation implementation"""
122
  query_embedding = [1.0, 0.0, 0.0]
123
  doc_embeddings = [
124
  [1.0, 0.0, 0.0], # Same as query - score should be ~1.0
125
  [0.0, 1.0, 0.0], # Orthogonal - score should be ~0.0
126
+ [0.5, 0.5, 0.0], # Partial similarity
127
  ]
128
+
129
  scores = _calculate_similarity_impl(query_embedding, doc_embeddings)
130
+
131
  # Check scores
132
  self.assertEqual(len(scores), 3)
133
  self.assertAlmostEqual(scores[0], 1.0, places=5)
134
  self.assertAlmostEqual(scores[1], 0.0, places=5)
135
  self.assertGreater(scores[2], 0.0)
136
  self.assertLess(scores[2], 1.0)
137
+
138
  def test_calculate_similarity_with_timer(self):
139
  """Test similarity calculation with timer"""
140
  mock_timer = Mock()
141
  mock_timer.time_step = MagicMock()
142
  mock_timer.time_step.return_value.__enter__ = Mock()
143
  mock_timer.time_step.return_value.__exit__ = Mock()
144
+
145
  query_emb = [1.0, 0.0, 0.0]
146
  doc_embs = [[1.0, 0.0, 0.0]]
147
+
148
  calculate_similarity(query_emb, doc_embs, timer=mock_timer)
149
+
150
  # Verify timer was used
151
  mock_timer.time_step.assert_called_once_with("similarity_calculation")
152
+
153
  def test_process_context_impl(self):
154
  """Test internal context processing implementation"""
155
  # Create mock results with metadata
 
157
  for i in range(3):
158
  mock_result = Mock()
159
  mock_result.metadata = {
160
+ "id": f"KB00{i+1}",
161
+ "question": f"Question {i+1}?",
162
+ "content": f"Answer {i+1}.",
163
  }
164
  results.append(mock_result)
165
+
166
  # Cosine scores (sorted: 0.9, 0.7, 0.5)
167
  cosine_scores = [0.7, 0.5, 0.9]
168
+
169
  context, source_ids, knowledge_pairs = _process_context_impl(
170
  results, cosine_scores, max_results=2
171
  )
172
+
173
  # Should return top 2 results
174
  self.assertEqual(len(source_ids), 2)
175
  self.assertEqual(len(knowledge_pairs), 2)
176
+
177
  # Check that highest score (0.9, index 2) is first
178
+ self.assertEqual(source_ids[0], "KB003")
179
+ self.assertEqual(knowledge_pairs[0][0], "Question 3?")
180
+
181
  # Check formatted context
182
  self.assertIn("Knowledge Entry 1:", context)
183
  self.assertIn("Knowledge Entry 2:", context)
184
  self.assertIn("Q: Question 3?", context)
185
  self.assertIn("A: Answer 3.", context)
186
+
187
  def test_process_context_with_timer(self):
188
  """Test context processing with timer"""
189
  mock_result = Mock()
190
+ mock_result.metadata = {"id": "KB001", "question": "Q?", "content": "A."}
191
+
192
  mock_timer = Mock()
193
  mock_timer.time_step = MagicMock()
194
  mock_timer.time_step.return_value.__enter__ = Mock()
195
  mock_timer.time_step.return_value.__exit__ = Mock()
196
+
197
  process_context([mock_result], [0.9], timer=mock_timer)
198
+
199
  # Verify timer was used
200
  mock_timer.time_step.assert_called_once_with("context_processing")
201
+
202
  def test_process_context_max_results(self):
203
  """Test that max_results parameter limits output"""
204
  # Create 5 mock results
 
206
  for i in range(5):
207
  mock_result = Mock()
208
  mock_result.metadata = {
209
+ "id": f"KB00{i}",
210
+ "question": f"Q{i}?",
211
+ "content": f"A{i}.",
212
  }
213
  results.append(mock_result)
214
+
215
  scores = [0.9, 0.8, 0.7, 0.6, 0.5]
216
+
217
  # Request only 3 results
218
  context, source_ids, knowledge_pairs = _process_context_impl(
219
  results, scores, max_results=3
220
  )
221
+
222
  # Should only return 3
223
  self.assertEqual(len(source_ids), 3)
224
  self.assertEqual(len(knowledge_pairs), 3)
225
+
226
  def test_process_context_formatting(self):
227
  """Test context formatting details"""
228
  mock_result = Mock()
229
  mock_result.metadata = {
230
+ "id": "KB001",
231
+ "question": "Test question?",
232
+ "content": "Test answer.",
233
  }
234
+
235
  context, _, _ = _process_context_impl([mock_result], [0.9], max_results=1)
236
+
237
  # Check formatting
238
  self.assertIn("Knowledge Entry 1:", context)
239
  self.assertIn("Q: Test question?", context)
240
  self.assertIn("A: Test answer.", context)
241
  self.assertIn("-" * 40, context)
242
+
243
  def test_process_context_missing_metadata(self):
244
  """Test context processing with missing metadata fields"""
245
  mock_result = Mock()
246
  mock_result.metadata = {} # No metadata
247
+
248
  context, source_ids, knowledge_pairs = _process_context_impl(
249
  [mock_result], [0.9], max_results=1
250
  )
251
+
252
  # Should handle missing fields with N/A
253
  self.assertIn("N/A", context)
254
  self.assertEqual(source_ids[0], "N/A")
255
 
256
+ @patch("src.vector_store.get_knowledge_base_data")
257
+ @patch("src.vector_store.chromadb.PersistentClient")
258
+ @patch("src.vector_store.Chroma")
259
+ def test_initialize_vector_store_new_collection(
260
+ self, mock_chroma_class, mock_client_class, mock_get_kb
261
+ ):
262
+ """Test initializing vector store with new collection"""
263
+ # Mock knowledge base data
264
+ mock_get_kb.return_value = (
265
+ ["doc1", "doc2"],
266
+ [{"id": "1"}, {"id": "2"}],
267
+ ["id1", "id2"],
268
+ )
269
+
270
+ # Mock ChromaDB client
271
+ mock_client = Mock()
272
+ mock_client_class.return_value = mock_client
273
+
274
+ # Simulate collection doesn't exist (raises exception)
275
+ mock_client.get_collection.side_effect = Exception("Collection not found")
276
+
277
+ # Mock create_collection
278
+ mock_collection = Mock()
279
+ mock_client.create_collection.return_value = mock_collection
280
+
281
+ # Mock Chroma vector store
282
+ mock_vector_store = Mock()
283
+ mock_retriever = Mock()
284
+ mock_vector_store.as_retriever.return_value = mock_retriever
285
+ mock_chroma_class.return_value = mock_vector_store
286
+
287
+ # Call function
288
+ from src.vector_store import initialize_vector_store
289
+
290
+ collection, vector_store, retriever = initialize_vector_store()
291
+
292
+ # Verify collection was created
293
+ mock_client.create_collection.assert_called_once()
294
+ mock_collection.add.assert_called_once()
295
+
296
+ # Verify vector store and retriever
297
+ self.assertEqual(vector_store, mock_vector_store)
298
+ self.assertEqual(retriever, mock_retriever)
299
+
300
+ @patch("src.vector_store.get_knowledge_base_data")
301
+ @patch("src.vector_store.chromadb.PersistentClient")
302
+ @patch("src.vector_store.Chroma")
303
+ def test_initialize_vector_store_existing_collection(
304
+ self, mock_chroma_class, mock_client_class, mock_get_kb
305
+ ):
306
+ """Test initializing vector store with existing collection"""
307
+ # Mock knowledge base data
308
+ mock_get_kb.return_value = (
309
+ ["doc1", "doc2"],
310
+ [{"id": "1"}, {"id": "2"}],
311
+ ["id1", "id2"],
312
+ )
313
+
314
+ # Mock ChromaDB client
315
+ mock_client = Mock()
316
+ mock_client_class.return_value = mock_client
317
+
318
+ # Simulate collection exists
319
+ mock_collection = Mock()
320
+ mock_client.get_collection.return_value = mock_collection
321
+
322
+ # Mock Chroma vector store
323
+ mock_vector_store = Mock()
324
+ mock_retriever = Mock()
325
+ mock_vector_store.as_retriever.return_value = mock_retriever
326
+ mock_chroma_class.return_value = mock_vector_store
327
+
328
+ # Call function
329
+ from src.vector_store import initialize_vector_store
330
+
331
+ collection, vector_store, retriever = initialize_vector_store()
332
+
333
+ # Verify existing collection was loaded (not created)
334
+ mock_client.get_collection.assert_called_once()
335
+ mock_client.create_collection.assert_not_called()
336
+
337
+ # Verify vector store and retriever
338
+ self.assertEqual(collection, mock_collection)
339
+ self.assertEqual(vector_store, mock_vector_store)
340
+ self.assertEqual(retriever, mock_retriever)
341
+
342
+ @patch("src.vector_store.get_knowledge_base_data")
343
+ @patch("src.vector_store.chromadb.PersistentClient")
344
+ def test_initialize_vector_store_failure(self, mock_client_class, mock_get_kb):
345
+ """Test initialize_vector_store handles errors properly"""
346
+ # Mock knowledge base data
347
+ mock_get_kb.return_value = (["doc1"], [{"id": "1"}], ["id1"])
348
+
349
+ # Mock client to raise exception
350
+ mock_client_class.side_effect = Exception("Database connection failed")
351
+
352
+ # Call function and expect exception
353
+ from src.vector_store import initialize_vector_store
354
+
355
+ with self.assertRaises(Exception) as context:
356
+ initialize_vector_store()
357
+
358
+ self.assertIn("Database connection failed", str(context.exception))
359
+
360
 
361
+ if __name__ == "__main__":
362
  unittest.main()
tox.ini ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tox]
2
+ envlist = py310,py311,format,lint
3
+ skipsdist = True
4
+
5
+ [testenv]
6
+ deps = -r requirements.txt
7
+ commands = pytest {posargs}
8
+
9
+ [testenv:format]
10
+ deps =
11
+ black
12
+ isort
13
+ autoflake
14
+ commands =
15
+ autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive src tests app.py
16
+ black src tests app.py
17
+ isort src tests app.py
18
+
19
+ [testenv:lint]
20
+ deps =
21
+ flake8
22
+ pylint
23
+ commands =
24
+ flake8 src tests
25
+ pylint src tests