Jaheen07 commited on
Commit
0b4a84c
·
verified ·
1 Parent(s): 67f4464

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +30 -0
  3. app.py +313 -0
  4. chatbot.py +1015 -0
  5. data/policies.pdf +3 -0
  6. requirements +11 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/policies.pdf filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies (a Java runtime is required by tabula-py)
RUN apt-get update && apt-get install -y \
    default-jre \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer caches independently of code
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY app.py .
COPY chatbot.py .
COPY data/ ./data/

# Create output directory
RUN mkdir -p /app/output

# Expose port
EXPOSE 7860

# Health check.
# Fix: a bare requests.get() succeeds on ANY HTTP response, even a 503 from
# /api/health, so the original probe could never report the app unhealthy.
# raise_for_status() makes 4xx/5xx answers fail the probe; the timeout stops
# a hung server from stalling the check until Docker's own timeout.
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:7860/api/health', timeout=5).raise_for_status()"

# Run application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Optional, List, Dict
5
+ import os
6
+ from datetime import datetime
7
+ import logging
8
+ import threading
9
+
10
+ from chatbot import RAGChatbot
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ app = FastAPI(
16
+ title="RAG Chatbot API - Multi-User",
17
+ description="HR Assistant Chatbot with Per-User Session Management",
18
+ version="2.0.0",
19
+ docs_url="/docs",
20
+ redoc_url="/redoc"
21
+ )
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # Global base chatbot instance
32
+ base_chatbot = None
33
+
34
+ # Per-user session storage
35
+ user_sessions = {}
36
+ session_lock = threading.Lock()
37
+
38
+ # Configuration
39
+ MAX_SESSIONS = 100
40
+ SESSION_TIMEOUT = 3600 # 1 hour
41
+
42
+
43
class UserSession:
    """Isolated per-user conversation state.

    Each user gets an independent chat log and a small entity memory used
    for pronoun resolution, so concurrent users never see each other's data.
    """

    def __init__(self, user_id: str):
        """Create an empty session for *user_id*."""
        self.user_id = user_id
        # Chronological list of chat-entry dicts appended by /api/chat.
        self.chat_history = []
        # Entity memory consumed by the chatbot's pronoun resolver.
        self.conversation_context = dict(
            current_employee=None,
            last_mentioned_entities=[],
        )
        # Timestamp used by the idle-session cleanup to expire this session.
        self.last_activity = datetime.now()

    def update_activity(self):
        """Mark the session as just-used so the cleanup job keeps it alive."""
        self.last_activity = datetime.now()
58
+
59
def cleanup_old_sessions():
    """Remove every session that has been idle longer than SESSION_TIMEOUT.

    Takes session_lock for the whole scan-and-delete, so it must NOT be
    called while the lock is already held (threading.Lock is non-reentrant).
    """
    with session_lock:
        now = datetime.now()
        expired = [
            uid
            for uid, sess in user_sessions.items()
            if (now - sess.last_activity).total_seconds() > SESSION_TIMEOUT
        ]
        for uid in expired:
            del user_sessions[uid]
            logger.info(f"Cleaned up session for user: {uid}")
75
def get_or_create_session(user_id: str) -> UserSession:
    """Return the session for *user_id*, creating it on first use.

    When the session table has grown past MAX_SESSIONS, idle sessions are
    expired inline, under the already-held lock.

    Bug fix: the original called cleanup_old_sessions() here while holding
    session_lock; that function re-acquires the same threading.Lock, which
    is non-reentrant, so the first time the table exceeded MAX_SESSIONS the
    server deadlocked.  The expiry is now done inline without re-locking.
    """
    with session_lock:
        if len(user_sessions) > MAX_SESSIONS:
            # Inline cleanup — do NOT call cleanup_old_sessions(): it takes
            # session_lock again and threading.Lock is not reentrant.
            now = datetime.now()
            expired = [
                uid
                for uid, sess in user_sessions.items()
                if (now - sess.last_activity).total_seconds() > SESSION_TIMEOUT
            ]
            for uid in expired:
                del user_sessions[uid]
                logger.info(f"Cleaned up session for user: {uid}")

        if user_id not in user_sessions:
            user_sessions[user_id] = UserSession(user_id)
            logger.info(f"Created new session for user: {user_id}")

        session = user_sessions[user_id]
        session.update_activity()
        return session
90
+ # Pydantic models
91
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""
    question: str  # natural-language question for the chatbot
    user_id: str   # caller-chosen identifier keying the per-user session
95
+
96
class ChatResponse(BaseModel):
    """Response body for POST /api/chat."""
    question: str       # echoed original question
    answer: str         # LLM-generated answer
    timestamp: str      # ISO-8601 generation time
    user_id: str        # echoed caller identifier
    session_info: Dict  # message count + employee currently in context
103
+
104
@app.on_event("startup")
async def startup_event():
    """Build the shared RAGChatbot once at process start.

    Fails fast (re-raises) when HF_TOKEN is missing or the policies PDF is
    absent, so the container never serves requests half-initialized.
    """
    global base_chatbot

    logger.info("=== Starting RAG Chatbot Initialization ===")

    try:
        pdf_path = os.getenv("PDF_PATH", "/app/data/policies.pdf")
        hf_token = os.getenv("HF_TOKEN")

        if not hf_token:
            raise ValueError("HF_TOKEN environment variable not set")

        logger.info(f"PDF Path: {pdf_path}")
        logger.info(f"File exists: {os.path.exists(pdf_path)}")

        if not os.path.exists(pdf_path):
            raise ValueError(f"PDF file not found at {pdf_path}")

        base_chatbot = RAGChatbot(pdf_path, hf_token)
        logger.info("=== Base chatbot initialized successfully! ===")

    except Exception as e:
        logger.error(f"Failed to initialize chatbot: {e}")
        raise
130
+
131
@app.get("/")
async def root():
    """Service metadata plus a quick directory of the API surface."""
    endpoint_map = {
        "docs": "/docs",
        "chat": "POST /api/chat",
        "history": "GET /api/history/{user_id}",
        "reset": "POST /api/reset?user_id=xxx",
        "sessions": "GET /api/sessions",
    }
    return {
        "service": "RAG Chatbot API",
        "version": "2.0.0",
        "status": "healthy",
        "active_sessions": len(user_sessions),
        "chatbot_loaded": base_chatbot is not None,
        "endpoints": endpoint_map,
    }
148
+
149
@app.get("/api/health")
async def health_check():
    """Readiness probe: 503 until the chatbot finished initializing."""
    if base_chatbot is None:
        raise HTTPException(status_code=503, detail="Chatbot not initialized")

    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "chatbot_ready": True,
        "active_sessions": len(user_sessions),
    }
161
+
162
@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer one question with full per-user session isolation.

    Pipeline: resolve pronouns from the user's context → retrieve PDF
    chunks → search this user's past turns → build prompt → call the LLM →
    update the user's context and history.
    """
    if base_chatbot is None:
        raise HTTPException(status_code=503, detail="Chatbot not initialized")
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    if not request.user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    try:
        logger.info(f"User {request.user_id}: {request.question[:50]}...")

        session = get_or_create_session(request.user_id)

        # Rewrite "his"/"her"/etc. using this user's own entity memory.
        resolved = base_chatbot._resolve_pronouns_for_session(
            request.question,
            session.conversation_context,
        )

        # RAG retrieval over the indexed policies PDF.
        retrieved = base_chatbot._retrieve(resolved, k=20)

        # Semantic search restricted to this user's own past turns.
        past_chats = base_chatbot._search_session_history(
            resolved,
            session.chat_history,
            k=5,
        )

        # Assemble the prompt from retrieval + session memory.
        prompt = base_chatbot._build_prompt_for_session(
            resolved,
            retrieved,
            past_chats,
            session.chat_history,
            session.conversation_context,
        )

        llm_response = base_chatbot.llm_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            model="meta-llama/Llama-3.1-8B-Instruct",
            max_tokens=512,
            temperature=0.3,
        )
        answer = llm_response.choices[0].message.content

        # Refresh the session's entity memory from this exchange.
        base_chatbot._update_conversation_context_for_session(
            request.question,
            answer,
            session.conversation_context,
        )

        # Append this turn to the user's private history.
        session.chat_history.append({
            'timestamp': datetime.now().isoformat(),
            'question': request.question,
            'answer': answer,
            'used_past_context': len(past_chats) > 0,
        })

        result = ChatResponse(
            question=request.question,
            answer=answer,
            timestamp=datetime.now().isoformat(),
            user_id=request.user_id,
            session_info={
                'total_messages': len(session.chat_history),
                'current_context': session.conversation_context.get('current_employee'),
            },
        )

        logger.info(f"User {request.user_id}: Question processed successfully")
        return result

    except Exception as e:
        logger.error(f"Error for user {request.user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
251
+
252
@app.post("/api/reset")
async def reset_chat(user_id: str):
    """Delete the session (history + context) for one user."""
    if not user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    with session_lock:
        if user_id not in user_sessions:
            # Idempotent: resetting a non-existent session still succeeds.
            return {"message": f"No session found for user {user_id}", "status": "success"}
        del user_sessions[user_id]
        logger.info(f"Reset session for user: {user_id}")
        return {"message": f"Chat history reset for user {user_id}", "status": "success"}
266
+
267
@app.get("/api/history/{user_id}")
async def get_history(user_id: str):
    """Return one user's chat log and current conversation context.

    NOTE(review): uses get_or_create_session, so querying an unknown user
    silently creates an empty session — presumably intentional.
    """
    session = get_or_create_session(user_id)
    return {
        "user_id": user_id,
        "total_conversations": len(session.chat_history),
        "current_context": session.conversation_context.get('current_employee'),
        "history": session.chat_history,
    }
279
+
280
@app.get("/api/sessions")
async def get_active_sessions():
    """Summarize every live session (admin/debug endpoint)."""
    with session_lock:
        summaries = []
        for user_id, session in user_sessions.items():
            summaries.append({
                "user_id": user_id,
                "messages": len(session.chat_history),
                "last_activity": session.last_activity.isoformat(),
                "current_context": session.conversation_context.get('current_employee'),
            })
        return {
            "total_sessions": len(user_sessions),
            "max_sessions": MAX_SESSIONS,
            "session_timeout_seconds": SESSION_TIMEOUT,
            "sessions": summaries,
        }
299
+
300
@app.post("/api/cleanup")
async def manual_cleanup():
    """Expire idle sessions on demand and report the remaining count."""
    cleanup_old_sessions()
    # len() read outside the lock — an approximate count is fine here.
    return {
        "message": "Cleanup completed",
        "active_sessions": len(user_sessions),
    }
309
+
310
if __name__ == "__main__":
    # Local development entry point; in the container uvicorn is started
    # directly by the Dockerfile CMD instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
chatbot.py ADDED
@@ -0,0 +1,1015 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Chatbot with Separate Table and Text Processing + Reinforcement Learning from Chat History
2
+ import PyPDF2
3
+ import faiss
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from huggingface_hub import InferenceClient
7
+ from typing import List, Tuple, Dict
8
+ import json
9
+ import re
10
+ import pandas as pd
11
+ import tabula.io as tabula
12
+ import os
13
+ import pickle
14
+ from datetime import datetime
15
+ from collections import Counter
16
+
17
+
18
+ class RAGChatbot:
19
+ def __init__(self, pdf_path: str, hf_token: str):
20
+ self.pdf_path = pdf_path
21
+ self.hf_token = hf_token
22
+ self.chunks = []
23
+ self.chunk_metadata = []
24
+ self.index = None
25
+ self.embeddings_model = None
26
+ self.llm_client = None
27
+ self.chat_history = []
28
+ self.output_dir = "./"
29
+ self.table_csv_path = None
30
+ self.text_chunks_path = None
31
+ self.history_file = os.path.join(self.output_dir, "chat_history.pkl")
32
+
33
+ # Chat history embeddings and index
34
+ self.chat_embeddings = []
35
+ self.chat_index = None
36
+ self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")
37
+
38
+ # Learning statistics
39
+ self.query_patterns = Counter()
40
+ self.feedback_scores = {}
41
+ self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")
42
+
43
+ # ADD THIS NEW SECTION:
44
+ self.conversation_context = {
45
+ 'current_employee': None,
46
+ 'last_mentioned_entities': []
47
+ }
48
+
49
+ os.makedirs(self.output_dir, exist_ok=True)
50
+
51
+ # Load existing chat history and learning data
52
+ self._load_chat_history()
53
+ self._load_learning_stats()
54
+
55
+ self._setup()
56
+
57
+ # Build chat history index after setup
58
+ self._build_chat_history_index()
59
+
60
+ def _load_chat_history(self):
61
+ """Load chat history from file if exists"""
62
+ if os.path.exists(self.history_file):
63
+ try:
64
+ with open(self.history_file, 'rb') as f:
65
+ self.chat_history = pickle.load(f)
66
+ print(f"Loaded {len(self.chat_history)} previous conversations")
67
+ except Exception as e:
68
+ print(f"Could not load chat history: {e}")
69
+ self.chat_history = []
70
+ else:
71
+ self.chat_history = []
72
+
73
+ def _save_chat_history(self):
74
+ """Save chat history to file"""
75
+ try:
76
+ with open(self.history_file, 'wb') as f:
77
+ pickle.dump(self.chat_history, f)
78
+ except Exception as e:
79
+ print(f"Could not save chat history: {e}")
80
+
81
+ def _load_learning_stats(self):
82
+ """Load learning statistics"""
83
+ if os.path.exists(self.stats_file):
84
+ try:
85
+ with open(self.stats_file, 'rb') as f:
86
+ data = pickle.load(f)
87
+ self.query_patterns = data.get('query_patterns', Counter())
88
+ self.feedback_scores = data.get('feedback_scores', {})
89
+ print(f"Loaded learning statistics: {len(self.query_patterns)} patterns tracked")
90
+ except Exception as e:
91
+ print(f"Could not load learning stats: {e}")
92
+ self.query_patterns = Counter()
93
+ self.feedback_scores = {}
94
+ else:
95
+ self.query_patterns = Counter()
96
+ self.feedback_scores = {}
97
+
98
+ def _save_learning_stats(self):
99
+ """Save learning statistics"""
100
+ try:
101
+ with open(self.stats_file, 'wb') as f:
102
+ pickle.dump({
103
+ 'query_patterns': self.query_patterns,
104
+ 'feedback_scores': self.feedback_scores
105
+ }, f)
106
+ except Exception as e:
107
+ print(f"Could not save learning stats: {e}")
108
+
109
+ def _build_chat_history_index(self):
110
+ """Build FAISS index from chat history for semantic search"""
111
+ if len(self.chat_history) == 0:
112
+ print("No chat history to index")
113
+ return
114
+
115
+ print(f"Building semantic index for {len(self.chat_history)} past conversations...")
116
+
117
+ # Create embeddings for all past Q&A pairs
118
+ chat_texts = []
119
+ for entry in self.chat_history:
120
+ # Combine question and answer for better context
121
+ combined_text = f"Q: {entry['question']}\nA: {entry['answer']}"
122
+ chat_texts.append(combined_text)
123
+
124
+ # Generate embeddings
125
+ self.chat_embeddings = self.embeddings_model.encode(chat_texts, show_progress_bar=True)
126
+
127
+ # Build FAISS index
128
+ dimension = self.chat_embeddings.shape[1]
129
+ self.chat_index = faiss.IndexFlatL2(dimension)
130
+ self.chat_index.add(np.array(self.chat_embeddings).astype('float32'))
131
+
132
+ # Save embeddings
133
+ try:
134
+ with open(self.chat_embedding_file, 'wb') as f:
135
+ pickle.dump(self.chat_embeddings, f)
136
+ except Exception as e:
137
+ print(f"Could not save chat embeddings: {e}")
138
+
139
+ print(f"Chat history index built successfully")
140
+
141
+ def _search_chat_history(self, query: str, k: int = 5) -> List[Dict]:
142
+ """Search through past conversations semantically"""
143
+ if self.chat_index is None or len(self.chat_history) == 0:
144
+ return []
145
+
146
+ # Encode query
147
+ query_embedding = self.embeddings_model.encode([query])
148
+
149
+ # Search
150
+ distances, indices = self.chat_index.search(
151
+ np.array(query_embedding).astype('float32'),
152
+ min(k, len(self.chat_history))
153
+ )
154
+
155
+ # Return relevant past conversations
156
+ relevant_chats = []
157
+ for idx, distance in zip(indices[0], distances[0]):
158
+ if distance < 1.5: # Similarity threshold
159
+ relevant_chats.append({
160
+ 'chat': self.chat_history[idx],
161
+ 'similarity_score': float(distance)
162
+ })
163
+
164
+ return relevant_chats
165
+
166
+ def _extract_entities_from_query(self, query: str) -> Dict:
167
+ """Extract names and entities from query"""
168
+ query_lower = query.lower()
169
+
170
+ # Check for pronouns that need context
171
+ has_pronoun = bool(re.search(r'\b(his|her|their|he|she|they|him|them)\b', query_lower))
172
+
173
+ # Try to extract names (capitalize words that might be names)
174
+ potential_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
175
+
176
+ return {
177
+ 'has_pronoun': has_pronoun,
178
+ 'names': potential_names
179
+ }
180
+
181
+ def _update_conversation_context(self, question: str, answer: str):
182
+ """Update context tracking based on conversation"""
183
+ # Extract names from question
184
+ names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
185
+
186
+ # Extract names from answer
187
+ answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
188
+
189
+ # Update current employee if employee was mentioned
190
+ if 'employee' in answer.lower() or 'working' in answer.lower():
191
+ all_names = names + answer_names
192
+ if all_names:
193
+ self.conversation_context['current_employee'] = all_names[0]
194
+ # Keep last 5 mentioned entities
195
+ self.conversation_context['last_mentioned_entities'] = (
196
+ all_names[:5] if len(all_names) <= 5
197
+ else self.conversation_context['last_mentioned_entities'][-4:] + [all_names[0]]
198
+ )
199
+
200
+ def _resolve_pronouns(self, query: str) -> str:
201
+ """Replace pronouns with actual entity names from context"""
202
+ entities = self._extract_entities_from_query(query)
203
+
204
+ if entities['has_pronoun'] and self.conversation_context['current_employee']:
205
+ current_name = self.conversation_context['current_employee']
206
+
207
+ # Replace pronouns with the current employee name
208
+ query = re.sub(r'\bhis\b', f"{current_name}'s", query, flags=re.IGNORECASE)
209
+ query = re.sub(r'\bher\b', f"{current_name}'s", query, flags=re.IGNORECASE)
210
+ query = re.sub(r'\bhe\b', current_name, query, flags=re.IGNORECASE)
211
+ query = re.sub(r'\bshe\b', current_name, query, flags=re.IGNORECASE)
212
+
213
+ return query
214
+
215
+
216
+ def _extract_query_pattern(self, query: str) -> str:
217
+ """Extract pattern from query for learning"""
218
+ query_lower = query.lower()
219
+
220
+ # Detect common patterns
221
+ patterns = []
222
+
223
+ if re.search(r'\bhow many\b', query_lower):
224
+ patterns.append('count_query')
225
+ if re.search(r'\bwho\b', query_lower):
226
+ patterns.append('who_query')
227
+ if re.search(r'\bwhat\b', query_lower):
228
+ patterns.append('what_query')
229
+ if re.search(r'\bwhen\b', query_lower):
230
+ patterns.append('when_query')
231
+ if re.search(r'\bwhere\b', query_lower):
232
+ patterns.append('where_query')
233
+ if re.search(r'\blist\b|\ball\b', query_lower):
234
+ patterns.append('list_query')
235
+ if re.search(r'\bcalculate\b|\bsum\b|\btotal\b|\baverage\b', query_lower):
236
+ patterns.append('calculation_query')
237
+ if re.search(r'\bemployee\b|\bstaff\b|\bworker\b', query_lower):
238
+ patterns.append('employee_query')
239
+ if re.search(r'\bpolicy\b|\brule\b|\bguideline\b', query_lower):
240
+ patterns.append('policy_query')
241
+
242
+ return '|'.join(patterns) if patterns else 'general_query'
243
+
244
+ def _load_pdf_text(self) -> str:
245
+ """Load text from PDF"""
246
+ text = ""
247
+ with open(self.pdf_path, 'rb') as file:
248
+ pdf_reader = PyPDF2.PdfReader(file)
249
+ for page in pdf_reader.pages:
250
+ text += page.extract_text()
251
+ return text
252
+
253
+ def _extract_and_merge_tables(self) -> str:
254
+ """Extract all tables from PDF and merge into single CSV"""
255
+ try:
256
+ print("Extracting tables from PDF...")
257
+
258
+ # Extract all tables
259
+ dfs = tabula.read_pdf(self.pdf_path, pages="all", multiple_tables=True)
260
+
261
+ if not dfs or len(dfs) == 0:
262
+ print("No tables found in PDF")
263
+ return None
264
+
265
+ print(f"Found {len(dfs)} tables")
266
+
267
+ # The first table has headers
268
+ merged_df = dfs[0]
269
+
270
+ # Append rest of the tables
271
+ for i in range(1, len(dfs)):
272
+ # Set the column names to match the first table
273
+ dfs[i].columns = merged_df.columns
274
+ # Append rows
275
+ merged_df = pd.concat([merged_df, dfs[i]], ignore_index=True)
276
+
277
+ # Save merged table
278
+ csv_path = os.path.join(self.output_dir, "merged_employee_tables.csv")
279
+ merged_df.to_csv(csv_path, index=False)
280
+
281
+ print(f"Merged {len(dfs)} tables into {csv_path}")
282
+ print(f"Total rows: {len(merged_df)}")
283
+ print(f"Columns: {list(merged_df.columns)}")
284
+
285
+ return csv_path
286
+
287
+ except Exception as e:
288
+ print(f"Table extraction failed: {e}")
289
+ return None
290
+
291
+ def _save_table_chunks(self, table_chunks: List[Dict]) -> str:
292
+ """Save table chunks (full table + row chunks) to a text file"""
293
+ save_path = os.path.join(self.output_dir, "table_chunks.txt")
294
+
295
+ with open(save_path, 'w', encoding='utf-8') as f:
296
+ f.write(f"Total Table Chunks: {len(table_chunks)}\n")
297
+ f.write("=" * 80 + "\n\n")
298
+
299
+ for i, chunk in enumerate(table_chunks):
300
+ f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
301
+ f.write("-" * 80 + "\n")
302
+ f.write(chunk['content'])
303
+ f.write("\n\n" + "=" * 80 + "\n\n")
304
+
305
+ print(f"Saved {len(table_chunks)} table chunks to {save_path}")
306
+ return save_path
307
+
308
+ def _detect_table_regions_in_text(self, text: str) -> List[Tuple[int, int]]:
309
+ """Detect start and end positions of table regions in text"""
310
+ lines = text.split('\n')
311
+ table_regions = []
312
+ start_idx = None
313
+
314
+ for i, line in enumerate(lines):
315
+ is_table_line = (
316
+ '@' in line or
317
+ re.search(r'\b(A|B|AB|O)[+-]\b', line) or
318
+ re.search(r'\s{3,}', line) or
319
+ re.search(r'Employee Name|Email|Position|Table|Blood Group', line, re.IGNORECASE)
320
+ )
321
+
322
+ if is_table_line:
323
+ if start_idx is None:
324
+ start_idx = i
325
+ else:
326
+ if start_idx is not None:
327
+ # End of table region
328
+ if i - start_idx > 3: # Only consider tables with 3+ lines
329
+ table_regions.append((start_idx, i))
330
+ start_idx = None
331
+
332
+ # Handle last table if exists
333
+ if start_idx is not None and len(lines) - start_idx > 3:
334
+ table_regions.append((start_idx, len(lines)))
335
+
336
+ return table_regions
337
+
338
+ def _remove_table_text(self, text: str) -> str:
339
+ """Remove table content from text"""
340
+ lines = text.split('\n')
341
+ table_regions = self._detect_table_regions_in_text(text)
342
+
343
+ # Create set of line indices to remove
344
+ lines_to_remove = set()
345
+ for start, end in table_regions:
346
+ for i in range(start, end):
347
+ lines_to_remove.add(i)
348
+
349
+ # Keep only non-table lines
350
+ clean_lines = [line for i, line in enumerate(lines) if i not in lines_to_remove]
351
+
352
+ return '\n'.join(clean_lines)
353
+
354
+ def _chunk_text_content(self, text: str) -> List[Dict]:
355
+ """Chunk text content (Q&A pairs and other text)"""
356
+ chunks = []
357
+
358
+ # Remove table text
359
+ clean_text = self._remove_table_text(text)
360
+
361
+ # Split by ###Question###
362
+ qa_pairs = clean_text.split('###Question###')
363
+
364
+ for i, qa in enumerate(qa_pairs):
365
+ if not qa.strip():
366
+ continue
367
+
368
+ if '###Answer###' in qa:
369
+ chunk_text = '###Question###' + qa
370
+ if len(chunk_text) > 50:
371
+ chunks.append({
372
+ 'content': chunk_text,
373
+ 'type': 'qa',
374
+ 'source': 'text_content',
375
+ 'chunk_id': f'qa_{i}'
376
+ })
377
+
378
+ # Also create chunks from sections (for non-Q&A content)
379
+ sections = re.split(r'\n\n+', clean_text)
380
+ for i, section in enumerate(sections):
381
+ section = section.strip()
382
+ if len(section) > 200 and '###Question###' not in section:
383
+ chunks.append({
384
+ 'content': section,
385
+ 'type': 'text',
386
+ 'source': 'text_content',
387
+ 'chunk_id': f'text_{i}'
388
+ })
389
+
390
+ return chunks
391
+
392
+ def _save_text_chunks(self, chunks: List[Dict]) -> str:
393
+ """Save text chunks to file"""
394
+ text_path = os.path.join(self.output_dir, "text_chunks.txt")
395
+
396
+ with open(text_path, 'w', encoding='utf-8') as f:
397
+ f.write(f"Total Text Chunks: {len(chunks)}\n")
398
+ f.write("=" * 80 + "\n\n")
399
+
400
+ for i, chunk in enumerate(chunks):
401
+ f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
402
+ f.write("-" * 80 + "\n")
403
+ f.write(chunk['content'])
404
+ f.write("\n\n" + "=" * 80 + "\n\n")
405
+
406
+ print(f"Saved {len(chunks)} text chunks to {text_path}")
407
+ return text_path
408
+
409
+ def _load_csv_as_text(self, csv_path: str) -> str:
410
+ """Load CSV and convert to readable text format"""
411
+ try:
412
+ df = pd.read_csv(csv_path)
413
+ text = f"[EMPLOYEE TABLE DATA]\n"
414
+ text += f"Total Employees: {len(df)}\n\n"
415
+ text += df.to_string(index=False)
416
+ return text
417
+ except Exception as e:
418
+ print(f"Error loading CSV: {e}")
419
+ return ""
420
+
421
+ def _create_table_chunks(self, csv_path: str) -> List[Dict]:
422
+ """Create chunks from CSV table"""
423
+ chunks = []
424
+
425
+ try:
426
+ df = pd.read_csv(csv_path)
427
+
428
+ # Create one chunk with full table overview
429
+ full_table_text = f"[COMPLETE EMPLOYEE TABLE]\n"
430
+ full_table_text += f"Total Employees: {len(df)}\n"
431
+ full_table_text += f"Columns: {', '.join(df.columns)}\n\n"
432
+ full_table_text += df.to_string(index=False)
433
+
434
+ chunks.append({
435
+ 'content': full_table_text,
436
+ 'type': 'table_full',
437
+ 'source': 'employee_table.csv',
438
+ 'chunk_id': 'table_full'
439
+ })
440
+
441
+ # Create chunks for each row (employee)
442
+ for idx, row in df.iterrows():
443
+ row_text = f"[EMPLOYEE RECORD {idx + 1}]\n"
444
+ for col in df.columns:
445
+ row_text += f"{col}: {row[col]}\n"
446
+
447
+ chunks.append({
448
+ 'content': row_text,
449
+ 'type': 'table_row',
450
+ 'source': 'employee_table.csv',
451
+ 'chunk_id': f'employee_{idx}'
452
+ })
453
+
454
+ print(f"Created {len(chunks)} chunks from table ({len(df)} employee records + 1 full table)")
455
+
456
+ except Exception as e:
457
+ print(f"Error creating table chunks: {e}")
458
+
459
+ return chunks
460
+
461
+ def _save_manifest(self, all_chunks: List[Dict]):
462
+ """Save manifest of all chunks"""
463
+ manifest = {
464
+ 'total_chunks': len(all_chunks),
465
+ 'chunks_by_type': {
466
+ 'qa': sum(1 for c in all_chunks if c['type'] == 'qa'),
467
+ 'text': sum(1 for c in all_chunks if c['type'] == 'text'),
468
+ 'table_full': sum(1 for c in all_chunks if c['type'] == 'table_full'),
469
+ 'table_row': sum(1 for c in all_chunks if c['type'] == 'table_row')
470
+ },
471
+ 'files_created': {
472
+ 'table_csv': self.table_csv_path,
473
+ 'text_chunks': self.text_chunks_path
474
+ },
475
+ 'chunk_details': [
476
+ {
477
+ 'chunk_id': c['chunk_id'],
478
+ 'type': c['type'],
479
+ 'source': c['source'],
480
+ 'length': len(c['content'])
481
+ }
482
+ for c in all_chunks
483
+ ]
484
+ }
485
+
486
+ manifest_path = os.path.join(self.output_dir, 'chunk_manifest.json')
487
+ with open(manifest_path, 'w', encoding='utf-8') as f:
488
+ json.dump(manifest, f, indent=2, ensure_ascii=False)
489
+
490
+ print(f"Saved manifest to {manifest_path}")
491
+ return manifest_path
492
+
493
+ def _resolve_pronouns_for_session(self, query: str, conversation_context: Dict) -> str:
494
+ """Resolve pronouns using session-specific context"""
495
+ entities = self._extract_entities_from_query(query)
496
+
497
+ if entities['has_pronoun'] and conversation_context.get('current_employee'):
498
+ current_name = conversation_context['current_employee']
499
+
500
+ query = re.sub(r'\bhis\b', f"{current_name}'s", query, flags=re.IGNORECASE)
501
+ query = re.sub(r'\bher\b', f"{current_name}'s", query, flags=re.IGNORECASE)
502
+ query = re.sub(r'\bhe\b', current_name, query, flags=re.IGNORECASE)
503
+ query = re.sub(r'\bshe\b', current_name, query, flags=re.IGNORECASE)
504
+
505
+ return query
506
+
507
+ def _search_session_history(self, query: str, session_history: List[Dict], k: int = 5) -> List[Dict]:
508
+ """Search through session-specific history"""
509
+ if not session_history:
510
+ return []
511
+
512
+ chat_texts = [f"Q: {entry['question']}\nA: {entry['answer']}" for entry in session_history]
513
+
514
+ if not chat_texts:
515
+ return []
516
+
517
+ chat_embeddings = self.embeddings_model.encode(chat_texts)
518
+
519
+ dimension = chat_embeddings.shape[1]
520
+ temp_index = faiss.IndexFlatL2(dimension)
521
+ temp_index.add(np.array(chat_embeddings).astype('float32'))
522
+
523
+ query_embedding = self.embeddings_model.encode([query])
524
+ distances, indices = temp_index.search(
525
+ np.array(query_embedding).astype('float32'),
526
+ min(k, len(session_history))
527
+ )
528
+
529
+ relevant_chats = []
530
+ for idx, distance in zip(indices[0], distances[0]):
531
+ if distance < 1.5:
532
+ relevant_chats.append({
533
+ 'chat': session_history[idx],
534
+ 'similarity_score': float(distance)
535
+ })
536
+
537
+ return relevant_chats
538
+
539
    def _build_prompt_for_session(self, query: str, retrieved_data: List[Tuple[str, Dict]],
                                  relevant_past_chats: List[Dict], session_history: List[Dict],
                                  conversation_context: Dict) -> str:
        """Build prompt using session-specific data.

        Assembles the LLM prompt from four sources: retrieved document
        chunks (bucketed by chunk type), the session's conversation context
        (current employee / recently mentioned entities), up to three
        semantically similar past Q&As, and the last 10 raw turns of this
        session's history.
        """
        # Bucket retrieved chunks by the 'type' recorded in their metadata.
        employee_records = []
        full_table = []
        qa_context = []
        text_context = []

        for content, metadata in retrieved_data:
            if metadata['type'] == 'table_row':
                employee_records.append(content)
            elif metadata['type'] == 'table_full':
                full_table.append(content)
            elif metadata['type'] == 'qa':
                qa_context.append(content)
            elif metadata['type'] == 'text':
                text_context.append(content)

        # Document context, most structured first; employee rows capped at 15.
        context_text = ""
        if full_table:
            context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
        if employee_records:
            context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
        if qa_context:
            context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
        if text_context:
            context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)

        # Who/what this session is currently discussing (pronoun-resolution aid).
        context_memory = ""
        if conversation_context.get('current_employee'):
            context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
            context_memory += f"Currently discussing: {conversation_context['current_employee']}\n"
            if conversation_context.get('last_mentioned_entities'):
                context_memory += f"Recently mentioned: {', '.join(conversation_context['last_mentioned_entities'])}\n"
            context_memory += "\n"

        # Up to 3 similar past Q&As retrieved from this session's history.
        past_context = ""
        if relevant_past_chats:
            past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
            for i, chat_info in enumerate(relevant_past_chats[:3], 1):
                chat = chat_info['chat']
                past_context += f"\n[Past Q&A {i}]:\n"
                past_context += f"Previous Question: {chat['question']}\n"
                past_context += f"Previous Answer: {chat['answer']}\n"
                past_context += "\n"

        # Last 10 raw turns of this session, verbatim.
        history_text = ""
        for entry in session_history[-10:]:
            history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"

        # Llama-style [INST] prompt; the instruction text below is part of
        # the model contract — do not reflow or reword it casually.
        prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.

IMPORTANT INSTRUCTIONS:
- You have access to COMPLETE EMPLOYEE TABLE and individual employee records
- For employee-related queries, use the employee data provided
- If you find any name from user input, always look into the EMPLOYEE TABLE first
- PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in THIS USER's recent conversation
- When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
- Be careful not to give all employee information - only answer what was asked
- For counting or calculations, use the table data
- For policy questions, use the Q&A knowledge base
- Provide specific, accurate answers based on the context
- If information is not in the context, say "I don't have this information"
- Round up any fractional numbers in calculations

Context:
{context_text}

{context_memory}

{past_context}

Recent conversation:
{history_text}

User Question: {query}

Answer based on the context above. Be specific and accurate.[/INST]"""

        return prompt
621
+
622
+ def _update_conversation_context_for_session(self, question: str, answer: str, conversation_context: Dict):
623
+ """Update session-specific conversation context"""
624
+ names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
625
+ answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
626
+
627
+ if 'employee' in answer.lower() or 'working' in answer.lower():
628
+ all_names = names + answer_names
629
+ if all_names:
630
+ conversation_context['current_employee'] = all_names[0]
631
+ conversation_context['last_mentioned_entities'] = (
632
+ all_names[:5] if len(all_names) <= 5
633
+ else conversation_context.get('last_mentioned_entities', [])[-4:] + [all_names[0]]
634
+ )
635
+
636
    def _setup(self):
        """One-time pipeline: parse the PDF, build chunks, embed them, and
        connect to the LLM inference endpoint.

        Side effects: writes the merged table CSV, the text-chunk file and
        the chunk manifest under ``self.output_dir``, and populates
        ``self.chunks``, ``self.chunk_metadata``, ``self.embeddings_model``,
        ``self.index`` and ``self.llm_client``.
        """
        print("\n" + "=" * 80)
        print("STEP 1: Loading PDF")
        print("=" * 80)

        text = self._load_pdf_text()
        print(f"Loaded PDF with {len(text)} characters")

        print("\n" + "=" * 80)
        print("STEP 2: Extracting and Merging Tables")
        print("=" * 80)

        # May be None when no table could be extracted from the PDF.
        self.table_csv_path = self._extract_and_merge_tables()

        print("\n" + "=" * 80)
        print("STEP 3: Chunking Text Content (Removing Tables)")
        print("=" * 80)

        text_chunks = self._chunk_text_content(text)
        self.text_chunks_path = self._save_text_chunks(text_chunks)

        print("\n" + "=" * 80)
        print("STEP 4: Creating Final Chunks")
        print("=" * 80)

        all_chunks = []

        # Add text chunks
        all_chunks.extend(text_chunks)

        # Add table chunks (only when a table was actually extracted)
        if self.table_csv_path:
            table_chunks = self._create_table_chunks(self.table_csv_path)
            all_chunks.extend(table_chunks)
            # Save chunked table text to file
            self._save_table_chunks(table_chunks)

        # Extract content and metadata; chunks[i] pairs with chunk_metadata[i]
        self.chunks = [c['content'] for c in all_chunks]
        self.chunk_metadata = all_chunks

        print(f"\nTotal chunks created: {len(self.chunks)}")
        print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
        print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
        print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
        print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")

        # Save manifest
        self._save_manifest(all_chunks)

        print("\n" + "=" * 80)
        print("STEP 5: Creating Embeddings")
        print("=" * 80)

        print("Loading embedding model...")
        self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        print("Creating embeddings for all chunks...")
        embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)

        print("Building FAISS index...")
        # Exact (brute-force) L2 index — no quantisation or training step.
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(np.array(embeddings).astype('float32'))

        print("\n" + "=" * 80)
        print("STEP 6: Initializing LLM")
        print("=" * 80)

        # Remote inference via Hugging Face; no model weights are loaded locally.
        self.llm_client = InferenceClient(token=self.hf_token)

        print("\n" + "=" * 80)
        print("SETUP COMPLETE!")
        print("=" * 80)
        print(f"Files created in: {self.output_dir}/")
        print(f" - {os.path.basename(self.table_csv_path) if self.table_csv_path else 'No table CSV'}")
        print(f" - {os.path.basename(self.text_chunks_path)}")
        print(f" - chunk_manifest.json")
        print(f" - {os.path.basename(self.history_file)}")
        print("=" * 80 + "\n")
716
+
717
+ def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
718
+ """Retrieve relevant chunks with metadata"""
719
+ query_embedding = self.embeddings_model.encode([query])
720
+ distances, indices = self.index.search(np.array(query_embedding).astype('float32'), k)
721
+
722
+ results = []
723
+ for idx in indices[0]:
724
+ results.append((self.chunks[idx], self.chunk_metadata[idx]))
725
+
726
+ return results
727
+
728
+ def _build_prompt(self, query: str, retrieved_data: List[Tuple[str, Dict]], relevant_past_chats: List[Dict]) -> str:
729
+ """Build prompt with retrieved context and learned information from past chats"""
730
+
731
+ # Separate different types of context
732
+ employee_records = []
733
+ full_table = []
734
+ qa_context = []
735
+ text_context = []
736
+
737
+ for content, metadata in retrieved_data:
738
+ if metadata['type'] == 'table_row':
739
+ employee_records.append(content)
740
+ elif metadata['type'] == 'table_full':
741
+ full_table.append(content)
742
+ elif metadata['type'] == 'qa':
743
+ qa_context.append(content)
744
+ elif metadata['type'] == 'text':
745
+ text_context.append(content)
746
+
747
+ # Build context sections
748
+ context_text = ""
749
+
750
+ if full_table:
751
+ context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
752
+
753
+ if employee_records:
754
+ context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
755
+
756
+ if qa_context:
757
+ context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
758
+
759
+ if text_context:
760
+ context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)
761
+
762
+ # ADD THIS NEW SECTION:
763
+ context_memory = ""
764
+ if self.conversation_context['current_employee']:
765
+ context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
766
+ context_memory += f"Currently discussing: {self.conversation_context['current_employee']}\n"
767
+ if self.conversation_context['last_mentioned_entities']:
768
+ context_memory += f"Recently mentioned: {', '.join(self.conversation_context['last_mentioned_entities'])}\n"
769
+ context_memory += "\n"
770
+
771
+ # Build relevant past conversations (learning from history)
772
+ past_context = ""
773
+ if relevant_past_chats:
774
+ past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
775
+ for i, chat_info in enumerate(relevant_past_chats[:3], 1):
776
+ chat = chat_info['chat']
777
+ past_context += f"\n[Past Q&A {i}]:\n"
778
+ past_context += f"Previous Question: {chat['question']}\n"
779
+ past_context += f"Previous Answer: {chat['answer']}\n"
780
+ past_context += "\n"
781
+
782
+ # CHANGE THIS LINE from [-3:] to [-10:]:
783
+ history_text = ""
784
+ for entry in self.chat_history: # Changed from -3 to -10
785
+ history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"
786
+
787
+ prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.
788
+
789
+ IMPORTANT INSTRUCTIONS:
790
+ - You have access to COMPLETE EMPLOYEE TABLE and individual employee records
791
+ - For employee-related queries, use the employee data provided
792
+ - If you find any name from user input, always look into the EMPLOYEE TABLE first. If you still can't find, then you can go for chunked text.
793
+ - PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in recent conversation
794
+ - When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
795
+ - While your answer is related to an employee, be careful of not giving all the information of the employee. Just give the information user asked.
796
+ - For counting or calculations, use the table data
797
+ - For policy questions, use the Q&A knowledge base
798
+ - LEARN from relevant past conversations - if similar questions were asked before, maintain consistency
799
+ - Use patterns from past interactions to improve answer quality
800
+ - Provide specific, accurate answers based on the context
801
+ - If you need to count employees or perform calculations, do it carefully from the data
802
+ - If information is not in the context, just say "I don't have this information in the provided documents"
803
+ - While performing any type of mathematical calculation, always round up any fractional number.
804
+
805
+ Context:
806
+ {context_text}
807
+
808
+ {context_memory}
809
+
810
+ {past_context}
811
+
812
+ Recent conversation:
813
+ {history_text}
814
+
815
+ User Question: {query}
816
+
817
+ Answer based on the context above. Be specific and accurate. But don't always start with "based on the context"[/INST]"""
818
+
819
+ return prompt
820
+
821
    def ask(self, question: str) -> str:
        """Answer *question* using RAG over the document index plus past chats.

        The special inputs "reset" / "reset data" clear the conversation
        memory instead of querying the LLM. Every normal turn is persisted
        to disk (history + learning stats) and folded into the chat-history
        FAISS index so later questions can retrieve it semantically.

        NOTE(review): errors from the remote LLM call propagate to the
        caller; the CLI loop at the bottom of this file catches them.
        """
        if question.lower() in ["reset data", "reset"]:
            self.chat_history = []
            self.chat_embeddings = []
            self.chat_index = None
            self.conversation_context = {'current_employee': None, 'last_mentioned_entities': []}
            self._save_chat_history()
            return "Chat history has been reset."

        # Replace pronouns with the employee under discussion so retrieval
        # and prompting see an explicit name.
        resolved_question = self._resolve_pronouns(question)

        # Track coarse query patterns for the learning statistics.
        pattern = self._extract_query_pattern(resolved_question)
        self.query_patterns[pattern] += 1

        # Similar past Q&As (for answer consistency) and document chunks
        # (for facts) — both retrieved with the pronoun-resolved query.
        relevant_past_chats = self._search_chat_history(resolved_question, k=10)

        retrieved_data = self._retrieve(resolved_question, k=20)

        prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)

        # Generate response via the hosted chat-completion endpoint.
        messages = [{"role": "user", "content": prompt}]

        response = self.llm_client.chat_completion(
            messages=messages,
            model="meta-llama/Llama-3.1-8B-Instruct",
            max_tokens=512,
            temperature=0.3
        )

        answer = response.choices[0].message.content

        # Remember who/what this turn was about (uses the raw question, not
        # the pronoun-resolved one, so the user's own phrasing is kept).
        self._update_conversation_context(question, answer)

        # Store in history with timestamp and metadata
        chat_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': question,
            'answer': answer,
            'pattern': pattern,
            'used_past_context': len(relevant_past_chats) > 0
        }

        self.chat_history.append(chat_entry)

        # Update chat history index with new conversation
        new_text = f"Q: {question}\nA: {answer}"
        new_embedding = self.embeddings_model.encode([new_text])

        if self.chat_index is None:
            # First turn: create the index lazily at the embedding's dimension.
            dimension = new_embedding.shape[1]
            self.chat_index = faiss.IndexFlatL2(dimension)
            self.chat_embeddings = new_embedding
        else:
            self.chat_embeddings = np.vstack([self.chat_embeddings, new_embedding])

        self.chat_index.add(np.array(new_embedding).astype('float32'))

        # Save to disk after each conversation
        self._save_chat_history()
        self._save_learning_stats()

        return answer
891
+
892
+ def provide_feedback(self, question: str, rating: int):
893
+ """Allow user to rate responses for reinforcement learning (1-5 scale)"""
894
+ if 1 <= rating <= 5:
895
+ # Find the most recent occurrence of this question
896
+ for i in range(len(self.chat_history) - 1, -1, -1):
897
+ if self.chat_history[i]['question'] == question:
898
+ chat_id = f"{i}_{self.chat_history[i]['timestamp']}"
899
+ self.feedback_scores[chat_id] = rating
900
+ self._save_learning_stats()
901
+ print(f"Feedback recorded: {rating}/5")
902
+ return
903
+ print("Question not found in recent history")
904
+ else:
905
+ print("Rating must be between 1 and 5")
906
+
907
+ def get_learning_insights(self) -> Dict:
908
+ """Get insights about what the chatbot has learned"""
909
+ total_conversations = len(self.chat_history)
910
+ conversations_with_past_context = sum(
911
+ 1 for c in self.chat_history if c.get('used_past_context', False)
912
+ )
913
+
914
+ avg_feedback = 0
915
+ if self.feedback_scores:
916
+ avg_feedback = sum(self.feedback_scores.values()) / len(self.feedback_scores)
917
+
918
+ return {
919
+ 'total_conversations': total_conversations,
920
+ 'conversations_using_past_context': conversations_with_past_context,
921
+ 'query_patterns': dict(self.query_patterns.most_common(10)),
922
+ 'total_feedback_entries': len(self.feedback_scores),
923
+ 'average_feedback_score': round(avg_feedback, 2)
924
+ }
925
+
926
+ def get_history(self) -> List[Dict]:
927
+ """Get chat history"""
928
+ return self.chat_history
929
+
930
    def display_stats(self):
        """Print a human-readable summary of chunk counts, learning metrics
        and the on-disk artefact locations."""
        # Tally chunks by the 'type' recorded in their metadata.
        qa_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'qa')
        text_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'text')
        table_full = sum(1 for c in self.chunk_metadata if c['type'] == 'table_full')
        table_rows = sum(1 for c in self.chunk_metadata if c['type'] == 'table_row')

        insights = self.get_learning_insights()

        print(f"\n{'=' * 80}")
        print("CHATBOT STATISTICS")
        print(f"{'=' * 80}")
        print(f"Total chunks: {len(self.chunks)}")
        print(f" - Q&A chunks: {qa_chunks}")
        print(f" - Text chunks: {text_chunks}")
        print(f" - Full table: {table_full}")
        print(f" - Employee records: {table_rows}")
        print(f"\nLEARNING STATISTICS:")
        print(f" - Total conversations: {insights['total_conversations']}")
        print(f" - Conversations using past context: {insights['conversations_using_past_context']}")
        print(f" - Total feedback entries: {insights['total_feedback_entries']}")
        print(f" - Average feedback score: {insights['average_feedback_score']}/5")
        print(f"\nTop query patterns:")
        # insights['query_patterns'] already holds the top 10; show only 5 here.
        for pattern, count in list(insights['query_patterns'].items())[:5]:
            print(f" - {pattern}: {count}")
        print(f"\nOutput directory: {self.output_dir}/")
        print(f"Table CSV: {os.path.basename(self.table_csv_path) if self.table_csv_path else 'None'}")
        print(f"Text chunks: {os.path.basename(self.text_chunks_path)}")
        print(f"History file: {os.path.basename(self.history_file)}")
        print(f"Learning stats: {os.path.basename(self.stats_file)}")
        print(f"{'=' * 80}\n")
961
+
962
+
963
# Main execution: interactive CLI around the chatbot.
if __name__ == "__main__":
    # Configuration
    PDF_PATH = "data/policies.pdf"
    HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face inference token (required)

    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable not set")

    # Initialize chatbot (runs the full PDF -> chunks -> embeddings pipeline)
    print("\nInitializing RAG Chatbot with Learning Capabilities...")
    bot = RAGChatbot(PDF_PATH, HF_TOKEN)

    # Display statistics
    bot.display_stats()

    # REPL: 'exit'/'quit'/'q' leave, 'stats' prints insights, 'feedback'
    # rates the previous answer; anything else is sent to the bot.
    print("Chatbot ready! Type 'exit' to quit, 'stats' for learning insights, or 'feedback' to rate last answer.\n")
    last_question = None  # remembered so 'feedback' can rate the last answer

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'q']:
            print("Goodbye!")
            break

        if user_input.lower() == 'stats':
            insights = bot.get_learning_insights()
            print("\nLearning Insights:")
            print(json.dumps(insights, indent=2))
            continue

        if user_input.lower() == 'feedback':
            if last_question:
                try:
                    rating = int(input("Rate the last answer (1-5): "))
                    bot.provide_feedback(last_question, rating)
                except ValueError:
                    print("Invalid rating")
            else:
                print("No previous question to rate")
            continue

        # Ignore blank input lines.
        if not user_input.strip():
            continue

        try:
            last_question = user_input
            answer = bot.ask(user_input)
            print(f"\nBot: {answer}\n")
        except Exception as e:
            # Surface the error but keep the REPL alive.
            print(f"Error: {e}\n")
data/policies.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0b5e498b8231c3c4fd80aeba1fe10b96627c40974e11d0be82b5fe47a83900b
3
+ size 338325
requirements ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ pydantic==2.5.0
4
+ PyPDF2==3.0.1
5
+ faiss-cpu==1.7.4
6
+ numpy==1.24.3
7
+ sentence-transformers==2.2.2
8
+ huggingface-hub==0.19.4
9
+ pandas==2.0.3
10
+ tabula-py==2.9.0
11
+ python-multipart==0.0.6