Jaheen07 commited on
Commit
fb8de8a
·
verified ·
1 Parent(s): a16ec11

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +32 -0
  3. README.md +10 -13
  4. app.py +328 -0
  5. chatbot.py +1021 -0
  6. data/policies.pdf +3 -0
  7. requirements.txt +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/policies.pdf filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies (Java needed for tabula)
6
+ RUN apt-get update && apt-get install -y \
7
+ default-jre \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy requirements and install
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy application files
15
+ COPY app.py .
16
+ COPY chatbot.py .
17
+ COPY data/ ./data/
18
+
19
+ # Create output directory
20
+ RUN mkdir -p /app/output
21
+
22
+ # Expose port
23
+ EXPOSE 7860
24
+
25
+ # Health check
26
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
27
+ CMD python -c "import requests; requests.get('http://localhost:7860/api/health')"
28
+
29
+ ENV PYTHONUNBUFFERED=1
30
+
31
+ # Run application
32
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,13 +1,10 @@
1
- ---
2
- title: Test Bot
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Rag Chatbot Api
3
+ emoji:
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
app.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict
import os
from datetime import datetime
import logging
import threading
import requests
from chatbot import RAGChatbot
# from dotenv import load_dotenv
#
# load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app; interactive API docs served at /docs and /redoc.
app = FastAPI(
    title="RAG Chatbot API - Multi-User",
    description="HR Assistant Chatbot with Per-User Session Management",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# fully open CORS policy — acceptable for a demo Space, tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global base chatbot instance (built once in the startup event, shared by all users)
base_chatbot = None

# Per-user session storage: user_id -> UserSession; guarded by session_lock
user_sessions = {}
session_lock = threading.Lock()

# Configuration
MAX_SESSIONS = 100  # soft cap on concurrent sessions before cleanup is attempted
SESSION_TIMEOUT = 3600  # 1 hour of inactivity before a session may be dropped
44
+
45
+
46
class UserSession:
    """Per-user conversation state, fully isolated from other users."""

    def __init__(self, user_id: str):
        # Identity plus an empty history for a brand-new session.
        self.user_id = user_id
        self.chat_history = []
        # Context used for pronoun resolution across this user's turns.
        context = {
            'current_employee': None,
            'last_mentioned_entities': [],
        }
        self.conversation_context = context
        self.last_activity = datetime.now()

    def update_activity(self):
        """Mark the session as just-used (consulted by timeout cleanup)."""
        self.last_activity = datetime.now()
60
+
61
+
62
def cleanup_old_sessions():
    """Drop every session idle for longer than SESSION_TIMEOUT seconds."""
    with session_lock:
        now = datetime.now()
        expired = [
            uid
            for uid, sess in user_sessions.items()
            if (now - sess.last_activity).total_seconds() > SESSION_TIMEOUT
        ]
        for user_id in expired:
            del user_sessions[user_id]
            logger.info(f"Cleaned up session for user: {user_id}")
76
+
77
+
78
def get_or_create_session(user_id: str) -> UserSession:
    """Return the session for *user_id*, creating one if needed.

    BUG FIX: the original called cleanup_old_sessions() while already holding
    session_lock. cleanup_old_sessions() acquires the same lock, and
    threading.Lock is not reentrant, so reaching MAX_SESSIONS deadlocked the
    whole server. The capacity check and cleanup now run *before* the lock is
    taken (the size read is only advisory, so the race is harmless).
    """
    if len(user_sessions) > MAX_SESSIONS:
        cleanup_old_sessions()

    with session_lock:
        if user_id not in user_sessions:
            user_sessions[user_id] = UserSession(user_id)
            logger.info(f"Created new session for user: {user_id}")

        session = user_sessions[user_id]
        session.update_activity()
        return session
91
+
92
+
93
# Pydantic request/response models
class ChatRequest(BaseModel):
    # Natural-language question; /api/chat rejects empty/whitespace values.
    question: str
    # Caller-supplied identifier selecting the isolated per-user session.
    user_id: str


class ChatResponse(BaseModel):
    question: str  # the original (unresolved) question, echoed back
    answer: str  # model-generated answer text
    timestamp: str  # ISO-8601 time the answer was produced
    user_id: str  # same identifier the caller sent
    session_info: Dict  # {'total_messages': int, 'current_context': str | None}
105
+
106
+
107
@app.on_event("startup")
async def startup_event():
    """Build the shared RAGChatbot once when the server starts.

    Reads PDF_PATH (default ./data/policies.pdf) and HF_TOKEN from the
    environment. Any failure is re-raised so the process exits unhealthy
    instead of serving 500s from a half-initialized bot.
    """
    global base_chatbot

    logger.info("=== Starting RAG Chatbot Initialization ===")

    try:
        PDF_PATH = os.getenv("PDF_PATH", "./data/policies.pdf")
        HF_TOKEN = os.getenv("HF_TOKEN")

        # NOTE(review): the HF_TOKEN presence check is commented out, so a
        # missing token only fails later, on the first router API call.
        # if not HF_TOKEN:
        #     raise ValueError("HF_TOKEN environment variable not set")

        logger.info(f"PDF Path: {PDF_PATH}")
        logger.info(f"File exists: {os.path.exists(PDF_PATH)}")

        if not os.path.exists(PDF_PATH):
            raise ValueError(f"PDF file not found at {PDF_PATH}")

        base_chatbot = RAGChatbot(PDF_PATH, HF_TOKEN)
        logger.info("=== Base chatbot initialized successfully! ===")

    except Exception as e:
        logger.error(f"Failed to initialize chatbot: {e}")
        raise
132
+
133
+
134
@app.get("/")
async def root():
    """Service metadata plus a quick directory of available endpoints."""
    endpoints = {
        "docs": "/docs",
        "chat": "POST /api/chat",
        "history": "GET /api/history/{user_id}",
        "reset": "POST /api/reset?user_id=xxx",
        "sessions": "GET /api/sessions",
    }
    return {
        "service": "RAG Chatbot API",
        "version": "2.0.0",
        "status": "healthy",
        "active_sessions": len(user_sessions),
        "chatbot_loaded": base_chatbot is not None,
        "endpoints": endpoints,
    }
150
+
151
+
152
@app.get("/api/health")
async def health_check():
    """Readiness probe (used by the Dockerfile HEALTHCHECK): 503 until startup
    has finished building the chatbot, 200 with basic stats afterwards."""
    if base_chatbot is None:
        raise HTTPException(status_code=503, detail="Chatbot not initialized")

    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "chatbot_ready": True,
        "active_sessions": len(user_sessions)
    }
163
+
164
+
165
@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer a question for one user, with per-user session isolation.

    Flow: validate input -> fetch/create the caller's session -> resolve
    pronouns against that session's context -> retrieve document chunks and
    relevant past Q&A -> build a prompt -> call the HF router chat API ->
    record the exchange in the session and return it.

    Raises:
        HTTPException: 503 before startup completes, 400 on empty question
            or missing user_id, 500 on any failure while answering.
    """
    if base_chatbot is None:
        raise HTTPException(status_code=503, detail="Chatbot not initialized")

    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if not request.user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    try:
        logger.info(f"User {request.user_id}: {request.question[:50]}...")

        # Get (or lazily create) this user's isolated session
        session = get_or_create_session(request.user_id)

        # Replace he/she/his/her with the session's tracked employee name
        resolved_question = base_chatbot._resolve_pronouns_for_session(
            request.question,
            session.conversation_context
        )

        # Retrieve relevant document/table chunks
        retrieved_data = base_chatbot._retrieve(resolved_question, k=20)

        # Semantic search over this user's own past Q&A
        relevant_past_chats = base_chatbot._search_session_history(
            resolved_question,
            session.chat_history,
            k=5
        )

        # Build the LLM prompt from retrieved context + session history
        prompt = base_chatbot._build_prompt_for_session(
            resolved_question,
            retrieved_data,
            relevant_past_chats,
            session.chat_history,
            session.conversation_context
        )

        # Call the Hugging Face Router chat-completions API (OpenAI-style payload)
        payload = {
            "model": base_chatbot.model_name,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 512,
            "temperature": 0.3
        }

        # NOTE(review): requests.post is blocking inside an async handler; it
        # ties up the event loop for up to 60s — consider a thread or httpx.
        response = requests.post(
            base_chatbot.api_url,
            headers=base_chatbot.headers,
            json=payload,
            timeout=60
        )
        response.raise_for_status()
        result = response.json()

        # Extract answer (OpenAI-compatible response shape)
        answer = result["choices"][0]["message"]["content"]

        # Update this session's pronoun-resolution context
        base_chatbot._update_conversation_context_for_session(
            request.question,
            answer,
            session.conversation_context
        )

        # Append the exchange to this user's in-memory history
        chat_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': request.question,
            'answer': answer,
            'used_past_context': len(relevant_past_chats) > 0
        }
        session.chat_history.append(chat_entry)

        response_data = ChatResponse(
            question=request.question,
            answer=answer,
            timestamp=datetime.now().isoformat(),
            user_id=request.user_id,
            session_info={
                'total_messages': len(session.chat_history),
                'current_context': session.conversation_context.get('current_employee')
            }
        )

        logger.info(f"User {request.user_id}: Question processed successfully")
        return response_data

    except Exception as e:
        logger.error(f"Error for user {request.user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
266
+
267
@app.post("/api/reset")
async def reset_chat(user_id: str):
    """Reset chat history for specific user"""
    if not user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    with session_lock:
        existed = user_id in user_sessions
        if existed:
            del user_sessions[user_id]
            logger.info(f"Reset session for user: {user_id}")
    if existed:
        return {"message": f"Chat history reset for user {user_id}", "status": "success"}
    return {"message": f"No session found for user {user_id}", "status": "success"}
280
+
281
+
282
@app.get("/api/history/{user_id}")
async def get_history(user_id: str):
    """Return the chat history for one user.

    FIX: the original called get_or_create_session(), so merely polling a
    read-only endpoint created (and kept alive) an empty session that counted
    toward MAX_SESSIONS. The session is now looked up without creating one;
    the response shape is unchanged — an unknown user gets an empty history,
    exactly what a freshly created session would have returned.
    """
    with session_lock:
        session = user_sessions.get(user_id)
        if session is not None:
            session.update_activity()
            # Copy under the lock so concurrent appends can't mutate the response.
            history = list(session.chat_history)
            current = session.conversation_context.get('current_employee')
        else:
            history = []
            current = None

    return {
        "user_id": user_id,
        "total_conversations": len(history),
        "current_context": current,
        "history": history
    }
293
+
294
+
295
@app.get("/api/sessions")
async def get_active_sessions():
    """Summarize every live session (monitoring/debugging endpoint)."""
    with session_lock:
        return {
            "total_sessions": len(user_sessions),
            "max_sessions": MAX_SESSIONS,
            "session_timeout_seconds": SESSION_TIMEOUT,
            "sessions": [
                {
                    "user_id": user_id,
                    "messages": len(session.chat_history),
                    "last_activity": session.last_activity.isoformat(),
                    "current_context": session.conversation_context.get('current_employee')
                }
                for user_id, session in user_sessions.items()
            ]
        }
313
+
314
+
315
@app.post("/api/cleanup")
async def manual_cleanup():
    """Manually evict sessions idle past SESSION_TIMEOUT; report the new count."""
    cleanup_old_sessions()
    return {
        "message": "Cleanup completed",
        "active_sessions": len(user_sessions)
    }
323
+
324
+
325
# Local-development entry point; in the container uvicorn is started by CMD.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
chatbot.py ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Chatbot with Separate Table and Text Processing + Reinforcement Learning from Chat History
2
+ import PyPDF2
3
+ import faiss
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from huggingface_hub import InferenceClient
7
+ from typing import List, Tuple, Dict
8
+ import json
9
+ import re
10
+ import pandas as pd
11
+ import tabula.io as tabula
12
+ import os
13
+ import pickle
14
+ from datetime import datetime
15
+ from collections import Counter
16
+ import requests
17
+
18
+
19
+
20
+ class RAGChatbot:
21
    def __init__(self, pdf_path: str, hf_token: str):
        """Initialize the chatbot: load persisted state, parse the PDF, build indexes.

        Args:
            pdf_path: Path to the source policy PDF.
            hf_token: Hugging Face token, sent as a Bearer token to the router API.
        """
        self.pdf_path = pdf_path
        self.hf_token = hf_token
        self.chunks = []  # retrievable text/table chunks
        self.chunk_metadata = []  # parallel metadata for self.chunks
        self.index = None  # FAISS index over chunk embeddings
        self.embeddings_model = None  # SentenceTransformer; populated by _setup()

        # Hugging Face router (OpenAI-compatible chat-completions) configuration
        self.api_url = "https://router.huggingface.co/v1/chat/completions"
        self.headers = {"Authorization": f"Bearer {hf_token}"}
        self.model_name = "meta-llama/Llama-3.3-70B-Instruct:sambanova"

        self.chat_history = []
        self.output_dir = "./"
        self.table_csv_path = None
        self.text_chunks_path = None
        self.history_file = os.path.join(self.output_dir, "chat_history.pkl")

        # Semantic index over past Q&A pairs (built in _build_chat_history_index)
        self.chat_embeddings = []
        self.chat_index = None
        self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")

        # Lightweight usage/"learning" statistics persisted across runs
        self.query_patterns = Counter()
        self.feedback_scores = {}
        self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")

        # Instance-global pronoun-resolution context (the API layer keeps a
        # per-user copy instead; see app.py UserSession).
        self.conversation_context = {
            'current_employee': None,
            'last_mentioned_entities': []
        }

        os.makedirs(self.output_dir, exist_ok=True)

        self._load_chat_history()
        self._load_learning_stats()

        # _setup() is defined elsewhere in this file — presumably it parses the
        # PDF and builds the main FAISS index; confirm against the full source.
        self._setup()

        self._build_chat_history_index()
61
+
62
    def _load_chat_history(self):
        """Load persisted chat history from self.history_file, if present.

        Falls back to an empty list when the file is missing or unreadable.
        NOTE(review): pickle.load on an attacker-writable file executes
        arbitrary code — safe only because this app writes the file itself.
        """
        if os.path.exists(self.history_file):
            try:
                with open(self.history_file, 'rb') as f:
                    self.chat_history = pickle.load(f)
                print(f"Loaded {len(self.chat_history)} previous conversations")
            except Exception as e:
                print(f"Could not load chat history: {e}")
                self.chat_history = []
        else:
            self.chat_history = []
74
+
75
    def _save_chat_history(self):
        """Persist self.chat_history to disk; failures are printed, not raised."""
        try:
            with open(self.history_file, 'wb') as f:
                pickle.dump(self.chat_history, f)
        except Exception as e:
            print(f"Could not save chat history: {e}")
82
+
83
    def _load_learning_stats(self):
        """Load query-pattern counts and feedback scores from self.stats_file.

        A missing or corrupt file resets both structures to empty.
        """
        if os.path.exists(self.stats_file):
            try:
                with open(self.stats_file, 'rb') as f:
                    data = pickle.load(f)
                self.query_patterns = data.get('query_patterns', Counter())
                self.feedback_scores = data.get('feedback_scores', {})
                print(f"Loaded learning statistics: {len(self.query_patterns)} patterns tracked")
            except Exception as e:
                print(f"Could not load learning stats: {e}")
                self.query_patterns = Counter()
                self.feedback_scores = {}
        else:
            self.query_patterns = Counter()
            self.feedback_scores = {}
99
+
100
    def _save_learning_stats(self):
        """Persist query patterns and feedback scores (best effort)."""
        try:
            with open(self.stats_file, 'wb') as f:
                pickle.dump({
                    'query_patterns': self.query_patterns,
                    'feedback_scores': self.feedback_scores
                }, f)
        except Exception as e:
            print(f"Could not save learning stats: {e}")
110
+
111
    def _build_chat_history_index(self):
        """Build a FAISS L2 index over embeddings of all past Q&A pairs.

        Requires self.embeddings_model (set in _setup(), which runs first in
        __init__). No-op when there is no history. The raw embeddings are also
        pickled to self.chat_embedding_file, best effort.
        """
        if len(self.chat_history) == 0:
            print("No chat history to index")
            return

        print(f"Building semantic index for {len(self.chat_history)} past conversations...")

        # Create embeddings for all past Q&A pairs
        chat_texts = []
        for entry in self.chat_history:
            # Combine question and answer so retrieval sees both sides
            combined_text = f"Q: {entry['question']}\nA: {entry['answer']}"
            chat_texts.append(combined_text)

        # Generate embeddings
        self.chat_embeddings = self.embeddings_model.encode(chat_texts, show_progress_bar=True)

        # Build FAISS index (exact L2 search)
        dimension = self.chat_embeddings.shape[1]
        self.chat_index = faiss.IndexFlatL2(dimension)
        self.chat_index.add(np.array(self.chat_embeddings).astype('float32'))

        # Save embeddings (best effort; the index itself is rebuilt on restart)
        try:
            with open(self.chat_embedding_file, 'wb') as f:
                pickle.dump(self.chat_embeddings, f)
        except Exception as e:
            print(f"Could not save chat embeddings: {e}")

        print(f"Chat history index built successfully")
142
+
143
    def _search_chat_history(self, query: str, k: int = 5) -> List[Dict]:
        """Return up to *k* past conversations semantically close to *query*.

        Each hit is {'chat': <history entry>, 'similarity_score': <L2 distance>};
        hits with distance >= 1.5 are discarded (smaller distance = closer).
        Returns [] when no index has been built yet.
        """
        if self.chat_index is None or len(self.chat_history) == 0:
            return []

        # Encode query
        query_embedding = self.embeddings_model.encode([query])

        # Search (never ask for more neighbors than entries exist)
        distances, indices = self.chat_index.search(
            np.array(query_embedding).astype('float32'),
            min(k, len(self.chat_history))
        )

        # Return relevant past conversations
        relevant_chats = []
        for idx, distance in zip(indices[0], distances[0]):
            if distance < 1.5:  # Similarity threshold (L2 distance)
                relevant_chats.append({
                    'chat': self.chat_history[idx],
                    'similarity_score': float(distance)
                })

        return relevant_chats
167
+
168
+ def _extract_entities_from_query(self, query: str) -> Dict:
169
+ """Extract names and entities from query"""
170
+ query_lower = query.lower()
171
+
172
+ # Check for pronouns that need context
173
+ has_pronoun = bool(re.search(r'\b(his|her|their|he|she|they|him|them)\b', query_lower))
174
+
175
+ # Try to extract names (capitalize words that might be names)
176
+ potential_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
177
+
178
+ return {
179
+ 'has_pronoun': has_pronoun,
180
+ 'names': potential_names
181
+ }
182
+
183
    def _update_conversation_context(self, question: str, answer: str):
        """Track which employee the conversation is currently about.

        Heuristic: capitalized word runs in question/answer are treated as
        names; when the answer mentions 'employee' or 'working', the first
        such name becomes the current employee and the recent-entities list
        is refreshed (capped at 5 entries).
        """
        # Extract names from question
        names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)

        # Extract names from answer
        answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)

        # Update current employee if employee was mentioned
        if 'employee' in answer.lower() or 'working' in answer.lower():
            all_names = names + answer_names
            if all_names:
                self.conversation_context['current_employee'] = all_names[0]
                # Keep last 5 mentioned entities
                self.conversation_context['last_mentioned_entities'] = (
                    all_names[:5] if len(all_names) <= 5
                    else self.conversation_context['last_mentioned_entities'][-4:] + [all_names[0]]
                )
201
+
202
    def _resolve_pronouns(self, query: str) -> str:
        """Rewrite he/she/his/her in *query* to the tracked current employee.

        Uses the instance-global conversation context; the multi-user API path
        uses the session-scoped variant _resolve_pronouns_for_session instead.
        NOTE(review): object pronouns 'him'/'them'/'they' are detected but
        never replaced, and 'her' is always rewritten as a possessive.
        """
        entities = self._extract_entities_from_query(query)

        if entities['has_pronoun'] and self.conversation_context['current_employee']:
            current_name = self.conversation_context['current_employee']

            # Possessive forms first, then subject pronouns
            query = re.sub(r'\bhis\b', f"{current_name}'s", query, flags=re.IGNORECASE)
            query = re.sub(r'\bher\b', f"{current_name}'s", query, flags=re.IGNORECASE)
            query = re.sub(r'\bhe\b', current_name, query, flags=re.IGNORECASE)
            query = re.sub(r'\bshe\b', current_name, query, flags=re.IGNORECASE)

        return query
216
+
217
+
218
+ def _extract_query_pattern(self, query: str) -> str:
219
+ """Extract pattern from query for learning"""
220
+ query_lower = query.lower()
221
+
222
+ # Detect common patterns
223
+ patterns = []
224
+
225
+ if re.search(r'\bhow many\b', query_lower):
226
+ patterns.append('count_query')
227
+ if re.search(r'\bwho\b', query_lower):
228
+ patterns.append('who_query')
229
+ if re.search(r'\bwhat\b', query_lower):
230
+ patterns.append('what_query')
231
+ if re.search(r'\bwhen\b', query_lower):
232
+ patterns.append('when_query')
233
+ if re.search(r'\bwhere\b', query_lower):
234
+ patterns.append('where_query')
235
+ if re.search(r'\blist\b|\ball\b', query_lower):
236
+ patterns.append('list_query')
237
+ if re.search(r'\bcalculate\b|\bsum\b|\btotal\b|\baverage\b', query_lower):
238
+ patterns.append('calculation_query')
239
+ if re.search(r'\bemployee\b|\bstaff\b|\bworker\b', query_lower):
240
+ patterns.append('employee_query')
241
+ if re.search(r'\bpolicy\b|\brule\b|\bguideline\b', query_lower):
242
+ patterns.append('policy_query')
243
+
244
+ return '|'.join(patterns) if patterns else 'general_query'
245
+
246
+ def _load_pdf_text(self) -> str:
247
+ """Load text from PDF"""
248
+ text = ""
249
+ with open(self.pdf_path, 'rb') as file:
250
+ pdf_reader = PyPDF2.PdfReader(file)
251
+ for page in pdf_reader.pages:
252
+ text += page.extract_text()
253
+ return text
254
+
255
    def _extract_and_merge_tables(self) -> str:
        """Extract every table from the PDF via tabula and merge them into one CSV.

        Assumes all tables share the schema of the first one: the first
        table's headers are forced onto every later table before concatenation.
        Returns the merged CSV path, or None when no tables are found or
        extraction fails.
        """
        try:
            print("Extracting tables from PDF...")

            # Extract all tables (tabula needs a Java runtime — see Dockerfile)
            dfs = tabula.read_pdf(self.pdf_path, pages="all", multiple_tables=True)

            if not dfs or len(dfs) == 0:
                print("No tables found in PDF")
                return None

            print(f"Found {len(dfs)} tables")

            # The first table has headers
            merged_df = dfs[0]

            # Append rest of the tables
            for i in range(1, len(dfs)):
                # Set the column names to match the first table
                dfs[i].columns = merged_df.columns
                # Append rows
                merged_df = pd.concat([merged_df, dfs[i]], ignore_index=True)

            # Save merged table
            csv_path = os.path.join(self.output_dir, "merged_employee_tables.csv")
            merged_df.to_csv(csv_path, index=False)

            print(f"Merged {len(dfs)} tables into {csv_path}")
            print(f"Total rows: {len(merged_df)}")
            print(f"Columns: {list(merged_df.columns)}")

            return csv_path

        except Exception as e:
            print(f"Table extraction failed: {e}")
            return None
292
+
293
    def _save_table_chunks(self, table_chunks: List[Dict]) -> str:
        """Write the table chunks to table_chunks.txt for inspection.

        Args:
            table_chunks: Chunk dicts with at least 'type' and 'content' keys.

        Returns:
            Path of the written file.
        """
        save_path = os.path.join(self.output_dir, "table_chunks.txt")

        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(f"Total Table Chunks: {len(table_chunks)}\n")
            f.write("=" * 80 + "\n\n")

            for i, chunk in enumerate(table_chunks):
                f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
                f.write("-" * 80 + "\n")
                f.write(chunk['content'])
                f.write("\n\n" + "=" * 80 + "\n\n")

        print(f"Saved {len(table_chunks)} table chunks to {save_path}")
        return save_path
309
+
310
+ def _detect_table_regions_in_text(self, text: str) -> List[Tuple[int, int]]:
311
+ """Detect start and end positions of table regions in text"""
312
+ lines = text.split('\n')
313
+ table_regions = []
314
+ start_idx = None
315
+
316
+ for i, line in enumerate(lines):
317
+ is_table_line = (
318
+ '@' in line or
319
+ re.search(r'\b(A|B|AB|O)[+-]\b', line) or
320
+ re.search(r'\s{3,}', line) or
321
+ re.search(r'Employee Name|Email|Position|Table|Blood Group', line, re.IGNORECASE)
322
+ )
323
+
324
+ if is_table_line:
325
+ if start_idx is None:
326
+ start_idx = i
327
+ else:
328
+ if start_idx is not None:
329
+ # End of table region
330
+ if i - start_idx > 3: # Only consider tables with 3+ lines
331
+ table_regions.append((start_idx, i))
332
+ start_idx = None
333
+
334
+ # Handle last table if exists
335
+ if start_idx is not None and len(lines) - start_idx > 3:
336
+ table_regions.append((start_idx, len(lines)))
337
+
338
+ return table_regions
339
+
340
+ def _remove_table_text(self, text: str) -> str:
341
+ """Remove table content from text"""
342
+ lines = text.split('\n')
343
+ table_regions = self._detect_table_regions_in_text(text)
344
+
345
+ # Create set of line indices to remove
346
+ lines_to_remove = set()
347
+ for start, end in table_regions:
348
+ for i in range(start, end):
349
+ lines_to_remove.add(i)
350
+
351
+ # Keep only non-table lines
352
+ clean_lines = [line for i, line in enumerate(lines) if i not in lines_to_remove]
353
+
354
+ return '\n'.join(clean_lines)
355
+
356
    def _chunk_text_content(self, text: str) -> List[Dict]:
        """Split non-table text into Q&A chunks and free-text chunks.

        Q&A chunks are delimited by the literal '###Question###' /
        '###Answer###' markers expected in the source PDF; remaining
        paragraphs longer than 200 chars become 'text' chunks. Table
        regions are stripped first.
        """
        chunks = []

        # Remove table text
        clean_text = self._remove_table_text(text)

        # Split by ###Question###
        qa_pairs = clean_text.split('###Question###')

        for i, qa in enumerate(qa_pairs):
            if not qa.strip():
                continue

            if '###Answer###' in qa:
                # Re-attach the marker that split() consumed
                chunk_text = '###Question###' + qa
                if len(chunk_text) > 50:  # skip trivially short fragments
                    chunks.append({
                        'content': chunk_text,
                        'type': 'qa',
                        'source': 'text_content',
                        'chunk_id': f'qa_{i}'
                    })

        # Also create chunks from sections (for non-Q&A content)
        sections = re.split(r'\n\n+', clean_text)
        for i, section in enumerate(sections):
            section = section.strip()
            if len(section) > 200 and '###Question###' not in section:
                chunks.append({
                    'content': section,
                    'type': 'text',
                    'source': 'text_content',
                    'chunk_id': f'text_{i}'
                })

        return chunks
393
+
394
    def _save_text_chunks(self, chunks: List[Dict]) -> str:
        """Write the text chunks to text_chunks.txt for inspection; return its path."""
        text_path = os.path.join(self.output_dir, "text_chunks.txt")

        with open(text_path, 'w', encoding='utf-8') as f:
            f.write(f"Total Text Chunks: {len(chunks)}\n")
            f.write("=" * 80 + "\n\n")

            for i, chunk in enumerate(chunks):
                f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
                f.write("-" * 80 + "\n")
                f.write(chunk['content'])
                f.write("\n\n" + "=" * 80 + "\n\n")

        print(f"Saved {len(chunks)} text chunks to {text_path}")
        return text_path
410
+
411
    def _load_csv_as_text(self, csv_path: str) -> str:
        """Render the merged employee CSV as a readable text block ('' on failure)."""
        try:
            df = pd.read_csv(csv_path)
            text = f"[EMPLOYEE TABLE DATA]\n"
            text += f"Total Employees: {len(df)}\n\n"
            text += df.to_string(index=False)
            return text
        except Exception as e:
            print(f"Error loading CSV: {e}")
            return ""
422
+
423
    def _create_table_chunks(self, csv_path: str) -> List[Dict]:
        """Build retrieval chunks from the merged CSV.

        Produces one 'table_full' chunk containing the whole table plus one
        'table_row' chunk per employee row, so retrieval can answer both
        aggregate and per-person questions. Returns [] on failure.
        """
        chunks = []

        try:
            df = pd.read_csv(csv_path)

            # Create one chunk with full table overview
            full_table_text = f"[COMPLETE EMPLOYEE TABLE]\n"
            full_table_text += f"Total Employees: {len(df)}\n"
            full_table_text += f"Columns: {', '.join(df.columns)}\n\n"
            full_table_text += df.to_string(index=False)

            chunks.append({
                'content': full_table_text,
                'type': 'table_full',
                'source': 'employee_table.csv',
                'chunk_id': 'table_full'
            })

            # Create chunks for each row (employee)
            for idx, row in df.iterrows():
                row_text = f"[EMPLOYEE RECORD {idx + 1}]\n"
                for col in df.columns:
                    row_text += f"{col}: {row[col]}\n"

                chunks.append({
                    'content': row_text,
                    'type': 'table_row',
                    'source': 'employee_table.csv',
                    'chunk_id': f'employee_{idx}'
                })

            print(f"Created {len(chunks)} chunks from table ({len(df)} employee records + 1 full table)")

        except Exception as e:
            print(f"Error creating table chunks: {e}")

        return chunks
462
+
463
    def _save_manifest(self, all_chunks: List[Dict]):
        """Write chunk_manifest.json summarizing every chunk; return its path.

        The manifest records per-type counts, generated-file paths, and a
        per-chunk id/type/source/length listing for debugging.
        """
        manifest = {
            'total_chunks': len(all_chunks),
            'chunks_by_type': {
                'qa': sum(1 for c in all_chunks if c['type'] == 'qa'),
                'text': sum(1 for c in all_chunks if c['type'] == 'text'),
                'table_full': sum(1 for c in all_chunks if c['type'] == 'table_full'),
                'table_row': sum(1 for c in all_chunks if c['type'] == 'table_row')
            },
            'files_created': {
                'table_csv': self.table_csv_path,
                'text_chunks': self.text_chunks_path
            },
            'chunk_details': [
                {
                    'chunk_id': c['chunk_id'],
                    'type': c['type'],
                    'source': c['source'],
                    'length': len(c['content'])
                }
                for c in all_chunks
            ]
        }

        manifest_path = os.path.join(self.output_dir, 'chunk_manifest.json')
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)

        print(f"Saved manifest to {manifest_path}")
        return manifest_path
494
+
495
def _resolve_pronouns_for_session(self, query: str, conversation_context: Dict) -> str:
    """Rewrite pronouns in *query* with the employee tracked in the session context.

    Substitution only happens when the query contains a pronoun AND the
    session context already knows who is being discussed.
    """
    entities = self._extract_entities_from_query(query)

    if not (entities['has_pronoun'] and conversation_context.get('current_employee')):
        return query

    name = conversation_context['current_employee']
    # Possessives first, so "his"/"her" are handled before the bare
    # subject pronouns below.
    substitutions = (
        (r'\bhis\b', f"{name}'s"),
        (r'\bher\b', f"{name}'s"),
        (r'\bhe\b', name),
        (r'\bshe\b', name),
    )
    for pattern, replacement in substitutions:
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)

    return query
508
+
509
def _search_session_history(self, query: str, session_history: List[Dict], k: int = 5) -> List[Dict]:
    """Semantically search this session's past Q&A pairs for entries similar to *query*.

    Builds a throwaway FAISS index over the session history on every call;
    acceptable because per-session histories are short.

    Returns a list of {'chat': entry, 'similarity_score': float} dicts,
    where the score is an L2 distance (smaller means more similar).
    """
    if not session_history:
        return []

    # Removed a redundant `if not chat_texts` guard: session_history is
    # non-empty here, so chat_texts can never be empty.
    chat_texts = [f"Q: {entry['question']}\nA: {entry['answer']}" for entry in session_history]

    chat_embeddings = self.embeddings_model.encode(chat_texts)

    dimension = chat_embeddings.shape[1]
    temp_index = faiss.IndexFlatL2(dimension)
    temp_index.add(np.array(chat_embeddings).astype('float32'))

    query_embedding = self.embeddings_model.encode([query])
    distances, indices = temp_index.search(
        np.array(query_embedding).astype('float32'),
        min(k, len(session_history))
    )

    relevant_chats = []
    for idx, distance in zip(indices[0], distances[0]):
        # 1.5 is an empirical L2-distance cutoff; larger distances are
        # treated as unrelated. TODO confirm against the embedding model.
        if distance < 1.5:
            relevant_chats.append({
                'chat': session_history[idx],
                'similarity_score': float(distance)
            })

    return relevant_chats
540
+
541
def _build_prompt_for_session(self, query: str, retrieved_data: List[Tuple[str, Dict]],
                              relevant_past_chats: List[Dict], session_history: List[Dict],
                              conversation_context: Dict) -> str:
    """Build prompt using session-specific data.

    Assembles a single [INST]-style prompt from four ingredients:
    retrieved document chunks, the session's conversation context,
    semantically similar past Q&A pairs, and the recent turn history.
    """

    # Bucket retrieved chunks by metadata type so each kind gets its
    # own heading in the prompt.
    employee_records = []
    full_table = []
    qa_context = []
    text_context = []

    for content, metadata in retrieved_data:
        if metadata['type'] == 'table_row':
            employee_records.append(content)
        elif metadata['type'] == 'table_full':
            full_table.append(content)
        elif metadata['type'] == 'qa':
            qa_context.append(content)
        elif metadata['type'] == 'text':
            text_context.append(content)

    context_text = ""
    if full_table:
        context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
    if employee_records:
        # Cap at 15 records to keep the prompt within the context budget.
        context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
    if qa_context:
        context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
    if text_context:
        context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)

    # Conversation memory: who this session is currently discussing
    # (lets the model resolve pronoun questions).
    context_memory = ""
    if conversation_context.get('current_employee'):
        context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
        context_memory += f"Currently discussing: {conversation_context['current_employee']}\n"
        if conversation_context.get('last_mentioned_entities'):
            context_memory += f"Recently mentioned: {', '.join(conversation_context['last_mentioned_entities'])}\n"
        context_memory += "\n"

    # Up to three semantically similar past exchanges, for answer consistency.
    past_context = ""
    if relevant_past_chats:
        past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
        for i, chat_info in enumerate(relevant_past_chats[:3], 1):
            chat = chat_info['chat']
            past_context += f"\n[Past Q&A {i}]:\n"
            past_context += f"Previous Question: {chat['question']}\n"
            past_context += f"Previous Answer: {chat['answer']}\n"
            past_context += "\n"

    # Verbatim transcript of (at most) the last 10 turns of this session.
    history_text = ""
    for entry in session_history[-10:]:
        history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"

    prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.

IMPORTANT INSTRUCTIONS:
- You have access to COMPLETE EMPLOYEE TABLE and individual employee records
- For employee-related queries, use the employee data provided
- If you find any name from user input, always look into the EMPLOYEE TABLE first
- PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in THIS USER's recent conversation
- When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
- Be careful not to give all employee information - only answer what was asked
- For counting or calculations, use the table data
- For policy questions, use the Q&A knowledge base
- Provide specific, accurate answers based on the context
- If information is not in the context, say "I don't have this information"
- Round up any fractional numbers in calculations

Context:
{context_text}

{context_memory}

{past_context}

Recent conversation:
{history_text}

User Question: {query}

Answer based on the context above. Be specific and accurate.[/INST]"""

    return prompt
623
+
624
+ def _update_conversation_context_for_session(self, question: str, answer: str, conversation_context: Dict):
625
+ """Update session-specific conversation context"""
626
+ names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
627
+ answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
628
+
629
+ if 'employee' in answer.lower() or 'working' in answer.lower():
630
+ all_names = names + answer_names
631
+ if all_names:
632
+ conversation_context['current_employee'] = all_names[0]
633
+ conversation_context['last_mentioned_entities'] = (
634
+ all_names[:5] if len(all_names) <= 5
635
+ else conversation_context.get('last_mentioned_entities', [])[-4:] + [all_names[0]]
636
+ )
637
+
638
def _setup(self):
    """One-time pipeline: load PDF -> extract tables -> chunk -> embed -> index.

    Populates self.table_csv_path, self.text_chunks_path, self.chunks,
    self.chunk_metadata, self.embeddings_model and self.index as side effects.
    """
    print("\n" + "=" * 80)
    print("STEP 1: Loading PDF")
    print("=" * 80)

    text = self._load_pdf_text()
    print(f"Loaded PDF with {len(text)} characters")

    print("\n" + "=" * 80)
    print("STEP 2: Extracting and Merging Tables")
    print("=" * 80)

    # May be None when no tables are found in the PDF.
    self.table_csv_path = self._extract_and_merge_tables()

    print("\n" + "=" * 80)
    print("STEP 3: Chunking Text Content (Removing Tables)")
    print("=" * 80)

    text_chunks = self._chunk_text_content(text)
    self.text_chunks_path = self._save_text_chunks(text_chunks)

    print("\n" + "=" * 80)
    print("STEP 4: Creating Final Chunks")
    print("=" * 80)

    all_chunks = []
    all_chunks.extend(text_chunks)

    # Table chunks are only added when table extraction succeeded.
    if self.table_csv_path:
        table_chunks = self._create_table_chunks(self.table_csv_path)
        all_chunks.extend(table_chunks)
        self._save_table_chunks(table_chunks)

    # Parallel lists: FAISS row i maps to both chunks[i] and chunk_metadata[i].
    self.chunks = [c['content'] for c in all_chunks]
    self.chunk_metadata = all_chunks

    print(f"\nTotal chunks created: {len(self.chunks)}")
    print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
    print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
    print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
    print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")

    self._save_manifest(all_chunks)

    print("\n" + "=" * 80)
    print("STEP 5: Creating Embeddings")
    print("=" * 80)

    print("Loading embedding model...")
    self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    print("Creating embeddings for all chunks...")
    embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)

    print("Building FAISS index...")
    # Exact (non-approximate) L2 index; fine at this corpus size.
    dimension = embeddings.shape[1]
    self.index = faiss.IndexFlatL2(dimension)
    self.index.add(np.array(embeddings).astype('float32'))

    print("\n" + "=" * 80)
    print("STEP 6: Initializing LLM API")
    print("=" * 80)

    # API endpoint and credentials were already configured in __init__.
    print(f"API URL: {self.api_url}")
    print(f"Model: {self.model_name}")
    print("LLM API ready!")

    print("\n" + "=" * 80)
    print("SETUP COMPLETE!")
    print("=" * 80)
709
+
710
+ def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
711
+ """Retrieve relevant chunks with metadata"""
712
+ query_embedding = self.embeddings_model.encode([query])
713
+ distances, indices = self.index.search(np.array(query_embedding).astype('float32'), k)
714
+
715
+ results = []
716
+ for idx in indices[0]:
717
+ results.append((self.chunks[idx], self.chunk_metadata[idx]))
718
+
719
+ return results
720
+
721
def _build_prompt(self, query: str, retrieved_data: List[Tuple[str, Dict]], relevant_past_chats: List[Dict]) -> str:
    """Build prompt with retrieved context and learned information from past chats.

    Global (non-session) variant: reads conversation state from
    self.conversation_context and self.chat_history.
    """

    # Separate different types of context so each gets its own heading.
    employee_records = []
    full_table = []
    qa_context = []
    text_context = []

    for content, metadata in retrieved_data:
        if metadata['type'] == 'table_row':
            employee_records.append(content)
        elif metadata['type'] == 'table_full':
            full_table.append(content)
        elif metadata['type'] == 'qa':
            qa_context.append(content)
        elif metadata['type'] == 'text':
            text_context.append(content)

    # Build context sections
    context_text = ""

    if full_table:
        context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"

    if employee_records:
        # Cap at 15 records to keep the prompt within the context budget.
        context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"

    if qa_context:
        context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"

    if text_context:
        context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)

    # Conversation memory: who is currently being discussed (for pronouns).
    context_memory = ""
    if self.conversation_context['current_employee']:
        context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
        context_memory += f"Currently discussing: {self.conversation_context['current_employee']}\n"
        if self.conversation_context['last_mentioned_entities']:
            context_memory += f"Recently mentioned: {', '.join(self.conversation_context['last_mentioned_entities'])}\n"
        context_memory += "\n"

    # Top 3 semantically similar past exchanges (learning from history).
    past_context = ""
    if relevant_past_chats:
        past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
        for i, chat_info in enumerate(relevant_past_chats[:3], 1):
            chat = chat_info['chat']
            past_context += f"\n[Past Q&A {i}]:\n"
            past_context += f"Previous Question: {chat['question']}\n"
            past_context += f"Previous Answer: {chat['answer']}\n"
            past_context += "\n"

    # NOTE(review): despite the "Recent conversation" heading, this includes
    # the ENTIRE chat history, so the prompt grows without bound over a long
    # session. Confirm whether a [-10:] slice (as in
    # _build_prompt_for_session) was intended.
    history_text = ""
    for entry in self.chat_history:
        history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"

    prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.

IMPORTANT INSTRUCTIONS:
- You have access to COMPLETE EMPLOYEE TABLE and individual employee records
- For employee-related queries, use the employee data provided
- If you find any name from user input, always look into the EMPLOYEE TABLE first. If you still can't find, then you can go for chunked text.
- PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in recent conversation
- When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
- While your answer is related to an employee, be careful of not giving all the information of the employee. Just give the information user asked.
- For counting or calculations, use the table data
- For policy questions, use the Q&A knowledge base
- LEARN from relevant past conversations - if similar questions were asked before, maintain consistency
- Use patterns from past interactions to improve answer quality
- Provide specific, accurate answers based on the context
- If you need to count employees or perform calculations, do it carefully from the data
- If information is not in the context, just say "I don't have this information in the provided documents"
- While performing any type of mathematical calculation, always round up any fractional number.

Context:
{context_text}

{context_memory}

{past_context}

Recent conversation:
{history_text}

User Question: {query}

Answer based on the context above. Be specific and accurate. But don't always start with "based on the context"[/INST]"""

    return prompt
813
+
814
def ask(self, question: str) -> str:
    """Ask a question to the chatbot with learning from past conversations.

    Side effects: appends to chat history, grows the chat-history FAISS
    index, and persists history/learning stats to disk on every call.
    """
    # Special command: wipe all accumulated learning state.
    if question.lower() in ["reset data", "reset"]:
        self.chat_history = []
        self.chat_embeddings = []
        self.chat_index = None
        self.conversation_context = {'current_employee': None, 'last_mentioned_entities': []}
        self._save_chat_history()
        return "Chat history has been reset."

    # Resolve pronouns before processing (he/she -> tracked employee name).
    resolved_question = self._resolve_pronouns(question)

    # Extract query pattern for learning statistics.
    pattern = self._extract_query_pattern(resolved_question)
    self.query_patterns[pattern] += 1

    # Search through past conversations for similar questions.
    relevant_past_chats = self._search_chat_history(resolved_question, k=5)

    # Retrieve relevant document chunks.
    retrieved_data = self._retrieve(resolved_question, k=20)

    # Build the full instruction prompt.
    prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)

    # Call the Hugging Face Router API (OpenAI-compatible chat completions).
    payload = {
        "model": self.model_name,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "max_tokens": 512,
        "temperature": 0.3
    }

    try:
        response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()

        # Extract answer from the OpenAI-style response shape.
        answer = result["choices"][0]["message"]["content"]

    except Exception as e:
        # Best-effort: any network/HTTP/parse failure degrades to an apology
        # rather than crashing the conversation loop.
        print(f"Error calling LLM API: {e}")
        answer = "I apologize, but I'm having trouble generating a response right now. Please try again."

    # Update conversation context (tracked employee, recent entities).
    self._update_conversation_context(question, answer)

    # Store this exchange in history.
    chat_entry = {
        'timestamp': datetime.now().isoformat(),
        'question': question,
        'answer': answer,
        'pattern': pattern,
        'used_past_context': len(relevant_past_chats) > 0
    }

    self.chat_history.append(chat_entry)

    # Fold the new Q&A pair into the incremental chat-history index.
    new_text = f"Q: {question}\nA: {answer}"
    new_embedding = self.embeddings_model.encode([new_text])

    if self.chat_index is None:
        # First entry: create the index lazily at the embedding's dimension.
        dimension = new_embedding.shape[1]
        self.chat_index = faiss.IndexFlatL2(dimension)
        self.chat_embeddings = new_embedding
    else:
        self.chat_embeddings = np.vstack([self.chat_embeddings, new_embedding])

    self.chat_index.add(np.array(new_embedding).astype('float32'))

    # Persist history and learning stats to disk.
    self._save_chat_history()
    self._save_learning_stats()

    return answer
897
+
898
def provide_feedback(self, question: str, rating: int):
    """Record a 1-5 user rating against the most recent answer to *question*."""
    if not 1 <= rating <= 5:
        print("Rating must be between 1 and 5")
        return

    # Walk the history newest-first so we rate the latest occurrence.
    for i in range(len(self.chat_history) - 1, -1, -1):
        entry = self.chat_history[i]
        if entry['question'] == question:
            chat_id = f"{i}_{entry['timestamp']}"
            self.feedback_scores[chat_id] = rating
            self._save_learning_stats()
            print(f"Feedback recorded: {rating}/5")
            return
    print("Question not found in recent history")
912
+
913
def get_learning_insights(self) -> Dict:
    """Summarise how often the bot reused past conversations, plus feedback stats."""
    history = self.chat_history
    reused = sum(1 for entry in history if entry.get('used_past_context', False))

    scores = self.feedback_scores
    # Average defaults to 0 when no feedback has been collected yet.
    avg_feedback = sum(scores.values()) / len(scores) if scores else 0

    return {
        'total_conversations': len(history),
        'conversations_using_past_context': reused,
        'query_patterns': dict(self.query_patterns.most_common(10)),
        'total_feedback_entries': len(scores),
        'average_feedback_score': round(avg_feedback, 2),
    }
931
+
932
def get_history(self) -> List[Dict]:
    """Return the accumulated chat history (list of Q/A entries, oldest first)."""
    return self.chat_history
935
+
936
def display_stats(self):
    """Print a summary of the chunk inventory, learning metrics and output files."""
    # Tally chunk counts per type in one pass over the metadata.
    counts = {'qa': 0, 'text': 0, 'table_full': 0, 'table_row': 0}
    for meta in self.chunk_metadata:
        kind = meta['type']
        if kind in counts:
            counts[kind] += 1

    insights = self.get_learning_insights()
    separator = '=' * 80

    print(f"\n{separator}")
    print("CHATBOT STATISTICS")
    print(f"{separator}")
    print(f"Total chunks: {len(self.chunks)}")
    print(f" - Q&A chunks: {counts['qa']}")
    print(f" - Text chunks: {counts['text']}")
    print(f" - Full table: {counts['table_full']}")
    print(f" - Employee records: {counts['table_row']}")
    print(f"\nLEARNING STATISTICS:")
    print(f" - Total conversations: {insights['total_conversations']}")
    print(f" - Conversations using past context: {insights['conversations_using_past_context']}")
    print(f" - Total feedback entries: {insights['total_feedback_entries']}")
    print(f" - Average feedback score: {insights['average_feedback_score']}/5")
    print(f"\nTop query patterns:")
    for pattern, count in list(insights['query_patterns'].items())[:5]:
        print(f" - {pattern}: {count}")
    print(f"\nOutput directory: {self.output_dir}/")
    print(f"Table CSV: {os.path.basename(self.table_csv_path) if self.table_csv_path else 'None'}")
    print(f"Text chunks: {os.path.basename(self.text_chunks_path)}")
    print(f"History file: {os.path.basename(self.history_file)}")
    print(f"Learning stats: {os.path.basename(self.stats_file)}")
    print(f"{separator}\n")
967
+
968
+
969
# Main execution: interactive command-line REPL for the chatbot.
if __name__ == "__main__":
    # Configuration
    PDF_PATH = "./data/policies.pdf"
    HF_TOKEN = os.getenv("HF_TOKEN")

    # Fail fast: the LLM API cannot be reached without a token.
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable not set")

    # Initialize chatbot (runs the full PDF -> chunks -> embeddings pipeline).
    print("\nInitializing RAG Chatbot with Learning Capabilities...")
    bot = RAGChatbot(PDF_PATH, HF_TOKEN)

    # Display statistics
    bot.display_stats()

    # REPL commands: 'exit'/'quit'/'q' quits, 'stats' dumps learning
    # insights, 'feedback' rates the previous answer.
    print("Chatbot ready! Type 'exit' to quit, 'stats' for learning insights, or 'feedback' to rate last answer.\n")
    last_question = None

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'q']:
            print("Goodbye!")
            break

        if user_input.lower() == 'stats':
            insights = bot.get_learning_insights()
            print("\nLearning Insights:")
            print(json.dumps(insights, indent=2))
            continue

        if user_input.lower() == 'feedback':
            if last_question:
                try:
                    rating = int(input("Rate the last answer (1-5): "))
                    bot.provide_feedback(last_question, rating)
                except ValueError:
                    # Non-numeric rating input.
                    print("Invalid rating")
            else:
                print("No previous question to rate")
            continue

        # Ignore blank input lines.
        if not user_input.strip():
            continue

        try:
            last_question = user_input
            answer = bot.ask(user_input)
            print(f"\nBot: {answer}\n")
        except Exception as e:
            # Keep the REPL alive on any failure.
            print(f"Error: {e}\n")
data/policies.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0b5e498b8231c3c4fd80aeba1fe10b96627c40974e11d0be82b5fe47a83900b
3
+ size 338325
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ pydantic==2.5.0
4
+ PyPDF2==3.0.1
5
+ faiss-cpu==1.7.4
6
+ numpy==1.24.3
7
+ sentence-transformers==2.2.2
8
+ requests==2.31.0
9
+ pandas==2.0.3
10
+ tabula-py==2.9.0
11
+ huggingface-hub==0.19.4
12
+ python-multipart==0.0.6