Jaheen07 commited on
Commit
44af2ac
·
verified ·
1 Parent(s): 68f3419

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +1014 -1014
chatbot.py CHANGED
@@ -1,1015 +1,1015 @@
1
- # RAG Chatbot with Separate Table and Text Processing + Reinforcement Learning from Chat History
2
- import PyPDF2
3
- import faiss
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer
6
- from huggingface_hub import InferenceClient
7
- from typing import List, Tuple, Dict
8
- import json
9
- import re
10
- import pandas as pd
11
- import tabula.io as tabula
12
- import os
13
- import pickle
14
- from datetime import datetime
15
- from collections import Counter
16
-
17
-
18
- class RAGChatbot:
19
def __init__(self, pdf_path: str, hf_token: str):
    """Initialize the RAG chatbot: load persisted state, index the PDF, and build the chat-history index.

    Args:
        pdf_path: Path to the source PDF (free text + embedded employee tables).
        hf_token: Hugging Face API token used by the hosted LLM client.
    """
    self.pdf_path = pdf_path
    self.hf_token = hf_token
    self.chunks = []            # raw chunk texts fed to the embedding model
    self.chunk_metadata = []    # parallel list of chunk dicts (type/source/chunk_id)
    self.index = None           # FAISS index over document chunks (built in _setup)
    self.embeddings_model = None
    self.llm_client = None
    self.chat_history = []
    self.output_dir = "./"
    self.table_csv_path = None
    self.text_chunks_path = None
    self.history_file = os.path.join(self.output_dir, "chat_history.pkl")

    # Chat history embeddings and index (semantic recall of past Q&A)
    self.chat_embeddings = []
    self.chat_index = None
    self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")

    # Learning statistics (query-shape counts, user feedback ratings)
    self.query_patterns = Counter()
    self.feedback_scores = {}
    self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")

    # Lightweight pronoun-resolution context for the current conversation
    self.conversation_context = {
        'current_employee': None,
        'last_mentioned_entities': []
    }

    os.makedirs(self.output_dir, exist_ok=True)

    # Load existing chat history and learning data before building anything
    self._load_chat_history()
    self._load_learning_stats()

    self._setup()

    # Build chat history index after setup (needs self.embeddings_model)
    self._build_chat_history_index()
- def _load_chat_history(self):
61
- """Load chat history from file if exists"""
62
- if os.path.exists(self.history_file):
63
- try:
64
- with open(self.history_file, 'rb') as f:
65
- self.chat_history = pickle.load(f)
66
- print(f"Loaded {len(self.chat_history)} previous conversations")
67
- except Exception as e:
68
- print(f"Could not load chat history: {e}")
69
- self.chat_history = []
70
- else:
71
- self.chat_history = []
72
-
73
- def _save_chat_history(self):
74
- """Save chat history to file"""
75
- try:
76
- with open(self.history_file, 'wb') as f:
77
- pickle.dump(self.chat_history, f)
78
- except Exception as e:
79
- print(f"Could not save chat history: {e}")
80
-
81
- def _load_learning_stats(self):
82
- """Load learning statistics"""
83
- if os.path.exists(self.stats_file):
84
- try:
85
- with open(self.stats_file, 'rb') as f:
86
- data = pickle.load(f)
87
- self.query_patterns = data.get('query_patterns', Counter())
88
- self.feedback_scores = data.get('feedback_scores', {})
89
- print(f"Loaded learning statistics: {len(self.query_patterns)} patterns tracked")
90
- except Exception as e:
91
- print(f"Could not load learning stats: {e}")
92
- self.query_patterns = Counter()
93
- self.feedback_scores = {}
94
- else:
95
- self.query_patterns = Counter()
96
- self.feedback_scores = {}
97
-
98
- def _save_learning_stats(self):
99
- """Save learning statistics"""
100
- try:
101
- with open(self.stats_file, 'wb') as f:
102
- pickle.dump({
103
- 'query_patterns': self.query_patterns,
104
- 'feedback_scores': self.feedback_scores
105
- }, f)
106
- except Exception as e:
107
- print(f"Could not save learning stats: {e}")
108
-
109
def _build_chat_history_index(self):
    """Embed every stored Q&A pair and index them in FAISS for semantic recall."""
    if len(self.chat_history) == 0:
        print("No chat history to index")
        return

    print(f"Building semantic index for {len(self.chat_history)} past conversations...")

    # One text per past exchange; question + answer together gives better context.
    chat_texts = [f"Q: {entry['question']}\nA: {entry['answer']}" for entry in self.chat_history]

    self.chat_embeddings = self.embeddings_model.encode(chat_texts, show_progress_bar=True)

    # Flat L2 index sized to the embedding dimensionality.
    self.chat_index = faiss.IndexFlatL2(self.chat_embeddings.shape[1])
    self.chat_index.add(np.array(self.chat_embeddings).astype('float32'))

    # Best-effort persistence of the raw embeddings.
    try:
        with open(self.chat_embedding_file, 'wb') as fh:
            pickle.dump(self.chat_embeddings, fh)
    except Exception as e:
        print(f"Could not save chat embeddings: {e}")

    print(f"Chat history index built successfully")
def _search_chat_history(self, query: str, k: int = 5) -> List[Dict]:
    """Return up to *k* past conversations semantically close to *query*.

    Only hits with L2 distance below 1.5 are kept; each result carries the
    original chat entry plus its distance as 'similarity_score'.
    """
    if self.chat_index is None or len(self.chat_history) == 0:
        return []

    query_vec = self.embeddings_model.encode([query])
    n_results = min(k, len(self.chat_history))
    distances, indices = self.chat_index.search(
        np.array(query_vec).astype('float32'),
        n_results
    )

    return [
        {'chat': self.chat_history[idx], 'similarity_score': float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if dist < 1.5  # similarity threshold
    ]
- def _extract_entities_from_query(self, query: str) -> Dict:
167
- """Extract names and entities from query"""
168
- query_lower = query.lower()
169
-
170
- # Check for pronouns that need context
171
- has_pronoun = bool(re.search(r'\b(his|her|their|he|she|they|him|them)\b', query_lower))
172
-
173
- # Try to extract names (capitalize words that might be names)
174
- potential_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
175
-
176
- return {
177
- 'has_pronoun': has_pronoun,
178
- 'names': potential_names
179
- }
180
-
181
- def _update_conversation_context(self, question: str, answer: str):
182
- """Update context tracking based on conversation"""
183
- # Extract names from question
184
- names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
185
-
186
- # Extract names from answer
187
- answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
188
-
189
- # Update current employee if employee was mentioned
190
- if 'employee' in answer.lower() or 'working' in answer.lower():
191
- all_names = names + answer_names
192
- if all_names:
193
- self.conversation_context['current_employee'] = all_names[0]
194
- # Keep last 5 mentioned entities
195
- self.conversation_context['last_mentioned_entities'] = (
196
- all_names[:5] if len(all_names) <= 5
197
- else self.conversation_context['last_mentioned_entities'][-4:] + [all_names[0]]
198
- )
199
-
200
def _resolve_pronouns(self, query: str) -> str:
    """Rewrite his/her/he/she in *query* to the employee currently under discussion."""
    entities = self._extract_entities_from_query(query)
    current_name = self.conversation_context['current_employee']

    if not (entities['has_pronoun'] and current_name):
        return query

    # Possessives first, then subject pronouns — same order as before.
    for pattern, replacement in (
        (r'\bhis\b', f"{current_name}'s"),
        (r'\bher\b', f"{current_name}'s"),
        (r'\bhe\b', current_name),
        (r'\bshe\b', current_name),
    ):
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)

    return query
- def _extract_query_pattern(self, query: str) -> str:
217
- """Extract pattern from query for learning"""
218
- query_lower = query.lower()
219
-
220
- # Detect common patterns
221
- patterns = []
222
-
223
- if re.search(r'\bhow many\b', query_lower):
224
- patterns.append('count_query')
225
- if re.search(r'\bwho\b', query_lower):
226
- patterns.append('who_query')
227
- if re.search(r'\bwhat\b', query_lower):
228
- patterns.append('what_query')
229
- if re.search(r'\bwhen\b', query_lower):
230
- patterns.append('when_query')
231
- if re.search(r'\bwhere\b', query_lower):
232
- patterns.append('where_query')
233
- if re.search(r'\blist\b|\ball\b', query_lower):
234
- patterns.append('list_query')
235
- if re.search(r'\bcalculate\b|\bsum\b|\btotal\b|\baverage\b', query_lower):
236
- patterns.append('calculation_query')
237
- if re.search(r'\bemployee\b|\bstaff\b|\bworker\b', query_lower):
238
- patterns.append('employee_query')
239
- if re.search(r'\bpolicy\b|\brule\b|\bguideline\b', query_lower):
240
- patterns.append('policy_query')
241
-
242
- return '|'.join(patterns) if patterns else 'general_query'
243
-
244
def _load_pdf_text(self) -> str:
    """Extract and concatenate the text of every page in the source PDF.

    Returns:
        The full document text; pages that yield no text contribute "".
    """
    with open(self.pdf_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() can return None/"" for image-only pages; guard with
        # `or ""` and join once instead of quadratic `text +=` concatenation.
        return "".join((page.extract_text() or "") for page in pdf_reader.pages)
def _extract_and_merge_tables(self) -> str:
    """Extract all tables from the PDF with tabula, merge them, and write one CSV.

    The first table is assumed to carry the header row; later tables are
    treated as continuations of it.

    Returns:
        Path to the merged CSV, or None when no tables could be extracted.
    """
    try:
        print("Extracting tables from PDF...")

        dfs = tabula.read_pdf(self.pdf_path, pages="all", multiple_tables=True)

        if not dfs:
            print("No tables found in PDF")
            return None

        print(f"Found {len(dfs)} tables")

        # The first table has headers.
        merged_df = dfs[0]

        for i in range(1, len(dfs)):
            # A continuation table must have the same column count as the header
            # table; previously a mismatch raised inside the assignment below and
            # aborted the WHOLE extraction. Skip the odd table instead.
            if dfs[i].shape[1] != merged_df.shape[1]:
                print(f"Skipping table {i + 1}: column count mismatch")
                continue
            # Align column names to the header table, then append the rows.
            dfs[i].columns = merged_df.columns
            merged_df = pd.concat([merged_df, dfs[i]], ignore_index=True)

        # Save merged table
        csv_path = os.path.join(self.output_dir, "merged_employee_tables.csv")
        merged_df.to_csv(csv_path, index=False)

        print(f"Merged {len(dfs)} tables into {csv_path}")
        print(f"Total rows: {len(merged_df)}")
        print(f"Columns: {list(merged_df.columns)}")

        return csv_path

    except Exception as e:
        print(f"Table extraction failed: {e}")
        return None
- def _save_table_chunks(self, table_chunks: List[Dict]) -> str:
292
- """Save table chunks (full table + row chunks) to a text file"""
293
- save_path = os.path.join(self.output_dir, "table_chunks.txt")
294
-
295
- with open(save_path, 'w', encoding='utf-8') as f:
296
- f.write(f"Total Table Chunks: {len(table_chunks)}\n")
297
- f.write("=" * 80 + "\n\n")
298
-
299
- for i, chunk in enumerate(table_chunks):
300
- f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
301
- f.write("-" * 80 + "\n")
302
- f.write(chunk['content'])
303
- f.write("\n\n" + "=" * 80 + "\n\n")
304
-
305
- print(f"Saved {len(table_chunks)} table chunks to {save_path}")
306
- return save_path
307
-
308
- def _detect_table_regions_in_text(self, text: str) -> List[Tuple[int, int]]:
309
- """Detect start and end positions of table regions in text"""
310
- lines = text.split('\n')
311
- table_regions = []
312
- start_idx = None
313
-
314
- for i, line in enumerate(lines):
315
- is_table_line = (
316
- '@' in line or
317
- re.search(r'\b(A|B|AB|O)[+-]\b', line) or
318
- re.search(r'\s{3,}', line) or
319
- re.search(r'Employee Name|Email|Position|Table|Blood Group', line, re.IGNORECASE)
320
- )
321
-
322
- if is_table_line:
323
- if start_idx is None:
324
- start_idx = i
325
- else:
326
- if start_idx is not None:
327
- # End of table region
328
- if i - start_idx > 3: # Only consider tables with 3+ lines
329
- table_regions.append((start_idx, i))
330
- start_idx = None
331
-
332
- # Handle last table if exists
333
- if start_idx is not None and len(lines) - start_idx > 3:
334
- table_regions.append((start_idx, len(lines)))
335
-
336
- return table_regions
337
-
338
def _remove_table_text(self, text: str) -> str:
    """Return *text* with every detected table region stripped out."""
    lines = text.split('\n')

    # Flatten the detected (start, end) regions into one set of line indices.
    drop = {
        i
        for start, end in self._detect_table_regions_in_text(text)
        for i in range(start, end)
    }

    return '\n'.join(line for i, line in enumerate(lines) if i not in drop)
def _chunk_text_content(self, text: str) -> List[Dict]:
    """Split non-table text into Q&A chunks plus free-text section chunks."""
    chunks = []

    # Work on text with table regions removed.
    clean_text = self._remove_table_text(text)

    # Q&A chunks: re-attach the marker stripped by split() and drop tiny fragments.
    for i, fragment in enumerate(clean_text.split('###Question###')):
        if not fragment.strip() or '###Answer###' not in fragment:
            continue
        qa_chunk = '###Question###' + fragment
        if len(qa_chunk) > 50:
            chunks.append({
                'content': qa_chunk,
                'type': 'qa',
                'source': 'text_content',
                'chunk_id': f'qa_{i}',
            })

    # Section chunks: paragraph blocks that are long enough and not Q&A.
    for i, section in enumerate(re.split(r'\n\n+', clean_text)):
        section = section.strip()
        if len(section) > 200 and '###Question###' not in section:
            chunks.append({
                'content': section,
                'type': 'text',
                'source': 'text_content',
                'chunk_id': f'text_{i}',
            })

    return chunks
- def _save_text_chunks(self, chunks: List[Dict]) -> str:
393
- """Save text chunks to file"""
394
- text_path = os.path.join(self.output_dir, "text_chunks.txt")
395
-
396
- with open(text_path, 'w', encoding='utf-8') as f:
397
- f.write(f"Total Text Chunks: {len(chunks)}\n")
398
- f.write("=" * 80 + "\n\n")
399
-
400
- for i, chunk in enumerate(chunks):
401
- f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
402
- f.write("-" * 80 + "\n")
403
- f.write(chunk['content'])
404
- f.write("\n\n" + "=" * 80 + "\n\n")
405
-
406
- print(f"Saved {len(chunks)} text chunks to {text_path}")
407
- return text_path
408
-
409
- def _load_csv_as_text(self, csv_path: str) -> str:
410
- """Load CSV and convert to readable text format"""
411
- try:
412
- df = pd.read_csv(csv_path)
413
- text = f"[EMPLOYEE TABLE DATA]\n"
414
- text += f"Total Employees: {len(df)}\n\n"
415
- text += df.to_string(index=False)
416
- return text
417
- except Exception as e:
418
- print(f"Error loading CSV: {e}")
419
- return ""
420
-
421
- def _create_table_chunks(self, csv_path: str) -> List[Dict]:
422
- """Create chunks from CSV table"""
423
- chunks = []
424
-
425
- try:
426
- df = pd.read_csv(csv_path)
427
-
428
- # Create one chunk with full table overview
429
- full_table_text = f"[COMPLETE EMPLOYEE TABLE]\n"
430
- full_table_text += f"Total Employees: {len(df)}\n"
431
- full_table_text += f"Columns: {', '.join(df.columns)}\n\n"
432
- full_table_text += df.to_string(index=False)
433
-
434
- chunks.append({
435
- 'content': full_table_text,
436
- 'type': 'table_full',
437
- 'source': 'employee_table.csv',
438
- 'chunk_id': 'table_full'
439
- })
440
-
441
- # Create chunks for each row (employee)
442
- for idx, row in df.iterrows():
443
- row_text = f"[EMPLOYEE RECORD {idx + 1}]\n"
444
- for col in df.columns:
445
- row_text += f"{col}: {row[col]}\n"
446
-
447
- chunks.append({
448
- 'content': row_text,
449
- 'type': 'table_row',
450
- 'source': 'employee_table.csv',
451
- 'chunk_id': f'employee_{idx}'
452
- })
453
-
454
- print(f"Created {len(chunks)} chunks from table ({len(df)} employee records + 1 full table)")
455
-
456
- except Exception as e:
457
- print(f"Error creating table chunks: {e}")
458
-
459
- return chunks
460
-
461
- def _save_manifest(self, all_chunks: List[Dict]):
462
- """Save manifest of all chunks"""
463
- manifest = {
464
- 'total_chunks': len(all_chunks),
465
- 'chunks_by_type': {
466
- 'qa': sum(1 for c in all_chunks if c['type'] == 'qa'),
467
- 'text': sum(1 for c in all_chunks if c['type'] == 'text'),
468
- 'table_full': sum(1 for c in all_chunks if c['type'] == 'table_full'),
469
- 'table_row': sum(1 for c in all_chunks if c['type'] == 'table_row')
470
- },
471
- 'files_created': {
472
- 'table_csv': self.table_csv_path,
473
- 'text_chunks': self.text_chunks_path
474
- },
475
- 'chunk_details': [
476
- {
477
- 'chunk_id': c['chunk_id'],
478
- 'type': c['type'],
479
- 'source': c['source'],
480
- 'length': len(c['content'])
481
- }
482
- for c in all_chunks
483
- ]
484
- }
485
-
486
- manifest_path = os.path.join(self.output_dir, 'chunk_manifest.json')
487
- with open(manifest_path, 'w', encoding='utf-8') as f:
488
- json.dump(manifest, f, indent=2, ensure_ascii=False)
489
-
490
- print(f"Saved manifest to {manifest_path}")
491
- return manifest_path
492
-
493
def _resolve_pronouns_for_session(self, query: str, conversation_context: Dict) -> str:
    """Rewrite his/her/he/she in *query* using a per-session context dict."""
    entities = self._extract_entities_from_query(query)
    current_name = conversation_context.get('current_employee')

    if not (entities['has_pronoun'] and current_name):
        return query

    # Possessives first, then subject pronouns — same order as before.
    for pattern, replacement in (
        (r'\bhis\b', f"{current_name}'s"),
        (r'\bher\b', f"{current_name}'s"),
        (r'\bhe\b', current_name),
        (r'\bshe\b', current_name),
    ):
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)

    return query
def _search_session_history(self, query: str, session_history: List[Dict], k: int = 5) -> List[Dict]:
    """Semantically search a single session's history with a throwaway FAISS index."""
    if not session_history:
        return []

    chat_texts = [f"Q: {entry['question']}\nA: {entry['answer']}" for entry in session_history]
    if not chat_texts:
        return []

    session_vecs = self.embeddings_model.encode(chat_texts)

    # Ephemeral index: this session's history is small and rebuilt per call.
    temp_index = faiss.IndexFlatL2(session_vecs.shape[1])
    temp_index.add(np.array(session_vecs).astype('float32'))

    query_vec = self.embeddings_model.encode([query])
    distances, indices = temp_index.search(
        np.array(query_vec).astype('float32'),
        min(k, len(session_history))
    )

    return [
        {'chat': session_history[idx], 'similarity_score': float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if dist < 1.5  # same threshold as the global history search
    ]
def _build_prompt_for_session(self, query: str, retrieved_data: List[Tuple[str, Dict]],
                              relevant_past_chats: List[Dict], session_history: List[Dict],
                              conversation_context: Dict) -> str:
    """Build the LLM prompt from session-scoped history and context.

    Args:
        query: The (already pronoun-resolved) user question.
        retrieved_data: (content, metadata) pairs from the document index.
        relevant_past_chats: Semantically similar past Q&A dicts (top 3 used).
        session_history: This session's chronological Q&A list (last 10 used).
        conversation_context: Session dict with 'current_employee' and
            'last_mentioned_entities' for pronoun grounding.

    Returns:
        A single [INST]-wrapped prompt string for the chat LLM.
    """
    # Bucket retrieved chunks by their metadata type.
    employee_records = []
    full_table = []
    qa_context = []
    text_context = []

    for content, metadata in retrieved_data:
        if metadata['type'] == 'table_row':
            employee_records.append(content)
        elif metadata['type'] == 'table_full':
            full_table.append(content)
        elif metadata['type'] == 'qa':
            qa_context.append(content)
        elif metadata['type'] == 'text':
            text_context.append(content)

    # Assemble the context section; individual employee records capped at 15.
    context_text = ""
    if full_table:
        context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
    if employee_records:
        context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
    if qa_context:
        context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
    if text_context:
        context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)

    # Who the conversation is currently about (helps pronoun questions).
    context_memory = ""
    if conversation_context.get('current_employee'):
        context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
        context_memory += f"Currently discussing: {conversation_context['current_employee']}\n"
        if conversation_context.get('last_mentioned_entities'):
            context_memory += f"Recently mentioned: {', '.join(conversation_context['last_mentioned_entities'])}\n"
        context_memory += "\n"

    # Up to three semantically similar past exchanges, verbatim.
    past_context = ""
    if relevant_past_chats:
        past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
        for i, chat_info in enumerate(relevant_past_chats[:3], 1):
            chat = chat_info['chat']
            past_context += f"\n[Past Q&A {i}]:\n"
            past_context += f"Previous Question: {chat['question']}\n"
            past_context += f"Previous Answer: {chat['answer']}\n"
            past_context += "\n"

    # Verbatim tail (last 10 turns) of this session's conversation.
    history_text = ""
    for entry in session_history[-10:]:
        history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"

    prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.

IMPORTANT INSTRUCTIONS:
- You have access to COMPLETE EMPLOYEE TABLE and individual employee records
- For employee-related queries, use the employee data provided
- If you find any name from user input, always look into the EMPLOYEE TABLE first
- PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in THIS USER's recent conversation
- When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
- Be careful not to give all employee information - only answer what was asked
- For counting or calculations, use the table data
- For policy questions, use the Q&A knowledge base
- Provide specific, accurate answers based on the context
- If information is not in the context, say "I don't have this information"
- Round up any fractional numbers in calculations

Context:
{context_text}

{context_memory}

{past_context}

Recent conversation:
{history_text}

User Question: {query}

Answer based on the context above. Be specific and accurate.[/INST]"""

    return prompt
- def _update_conversation_context_for_session(self, question: str, answer: str, conversation_context: Dict):
623
- """Update session-specific conversation context"""
624
- names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
625
- answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
626
-
627
- if 'employee' in answer.lower() or 'working' in answer.lower():
628
- all_names = names + answer_names
629
- if all_names:
630
- conversation_context['current_employee'] = all_names[0]
631
- conversation_context['last_mentioned_entities'] = (
632
- all_names[:5] if len(all_names) <= 5
633
- else conversation_context.get('last_mentioned_entities', [])[-4:] + [all_names[0]]
634
- )
635
-
636
def _setup(self):
    """One-time pipeline: parse the PDF, chunk it, embed it, and init the LLM client.

    Populates self.table_csv_path, self.text_chunks_path, self.chunks,
    self.chunk_metadata, self.embeddings_model, self.index and self.llm_client.
    """
    print("\n" + "=" * 80)
    print("STEP 1: Loading PDF")
    print("=" * 80)

    text = self._load_pdf_text()
    print(f"Loaded PDF with {len(text)} characters")

    print("\n" + "=" * 80)
    print("STEP 2: Extracting and Merging Tables")
    print("=" * 80)

    # May be None when the PDF contains no extractable tables.
    self.table_csv_path = self._extract_and_merge_tables()

    print("\n" + "=" * 80)
    print("STEP 3: Chunking Text Content (Removing Tables)")
    print("=" * 80)

    text_chunks = self._chunk_text_content(text)
    self.text_chunks_path = self._save_text_chunks(text_chunks)

    print("\n" + "=" * 80)
    print("STEP 4: Creating Final Chunks")
    print("=" * 80)

    all_chunks = []

    # Add text chunks
    all_chunks.extend(text_chunks)

    # Add table chunks (only when table extraction succeeded)
    if self.table_csv_path:
        table_chunks = self._create_table_chunks(self.table_csv_path)
        all_chunks.extend(table_chunks)
        # Save chunked table text to file for inspection
        self._save_table_chunks(table_chunks)

    # Keep content and metadata in two parallel structures for retrieval.
    self.chunks = [c['content'] for c in all_chunks]
    self.chunk_metadata = all_chunks

    print(f"\nTotal chunks created: {len(self.chunks)}")
    print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
    print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
    print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
    print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")

    # Save manifest describing every chunk
    self._save_manifest(all_chunks)

    print("\n" + "=" * 80)
    print("STEP 5: Creating Embeddings")
    print("=" * 80)

    print("Loading embedding model...")
    self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    print("Creating embeddings for all chunks...")
    embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)

    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    self.index = faiss.IndexFlatL2(dimension)
    self.index.add(np.array(embeddings).astype('float32'))

    print("\n" + "=" * 80)
    print("STEP 6: Initializing LLM")
    print("=" * 80)

    self.llm_client = InferenceClient(token=self.hf_token)

    print("\n" + "=" * 80)
    print("SETUP COMPLETE!")
    print("=" * 80)
    print(f"Files created in: {self.output_dir}/")
    print(f" - {os.path.basename(self.table_csv_path) if self.table_csv_path else 'No table CSV'}")
    print(f" - {os.path.basename(self.text_chunks_path)}")
    print(f" - chunk_manifest.json")
    print(f" - {os.path.basename(self.history_file)}")
    print("=" * 80 + "\n")
def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
    """Return the top-*k* document chunks as (content, metadata) pairs for *query*."""
    query_vec = self.embeddings_model.encode([query])
    distances, indices = self.index.search(np.array(query_vec).astype('float32'), k)
    return [(self.chunks[i], self.chunk_metadata[i]) for i in indices[0]]
- def _build_prompt(self, query: str, retrieved_data: List[Tuple[str, Dict]], relevant_past_chats: List[Dict]) -> str:
729
- """Build prompt with retrieved context and learned information from past chats"""
730
-
731
- # Separate different types of context
732
- employee_records = []
733
- full_table = []
734
- qa_context = []
735
- text_context = []
736
-
737
- for content, metadata in retrieved_data:
738
- if metadata['type'] == 'table_row':
739
- employee_records.append(content)
740
- elif metadata['type'] == 'table_full':
741
- full_table.append(content)
742
- elif metadata['type'] == 'qa':
743
- qa_context.append(content)
744
- elif metadata['type'] == 'text':
745
- text_context.append(content)
746
-
747
- # Build context sections
748
- context_text = ""
749
-
750
- if full_table:
751
- context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
752
-
753
- if employee_records:
754
- context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
755
-
756
- if qa_context:
757
- context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
758
-
759
- if text_context:
760
- context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)
761
-
762
- # ADD THIS NEW SECTION:
763
- context_memory = ""
764
- if self.conversation_context['current_employee']:
765
- context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
766
- context_memory += f"Currently discussing: {self.conversation_context['current_employee']}\n"
767
- if self.conversation_context['last_mentioned_entities']:
768
- context_memory += f"Recently mentioned: {', '.join(self.conversation_context['last_mentioned_entities'])}\n"
769
- context_memory += "\n"
770
-
771
- # Build relevant past conversations (learning from history)
772
- past_context = ""
773
- if relevant_past_chats:
774
- past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
775
- for i, chat_info in enumerate(relevant_past_chats[:3], 1):
776
- chat = chat_info['chat']
777
- past_context += f"\n[Past Q&A {i}]:\n"
778
- past_context += f"Previous Question: {chat['question']}\n"
779
- past_context += f"Previous Answer: {chat['answer']}\n"
780
- past_context += "\n"
781
-
782
- # CHANGE THIS LINE from [-3:] to [-10:]:
783
- history_text = ""
784
- for entry in self.chat_history: # Changed from -3 to -10
785
- history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"
786
-
787
- prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.
788
-
789
- IMPORTANT INSTRUCTIONS:
790
- - You have access to COMPLETE EMPLOYEE TABLE and individual employee records
791
- - For employee-related queries, use the employee data provided
792
- - If you find any name from user input, always look into the EMPLOYEE TABLE first. If you still can't find, then you can go for chunked text.
793
- - PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in recent conversation
794
- - When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
795
- - While your answer is related to an employee, be careful of not giving all the information of the employee. Just give the information user asked.
796
- - For counting or calculations, use the table data
797
- - For policy questions, use the Q&A knowledge base
798
- - LEARN from relevant past conversations - if similar questions were asked before, maintain consistency
799
- - Use patterns from past interactions to improve answer quality
800
- - Provide specific, accurate answers based on the context
801
- - If you need to count employees or perform calculations, do it carefully from the data
802
- - If information is not in the context, just say "I don't have this information in the provided documents"
803
- - While performing any type of mathematical calculation, always round up any fractional number.
804
-
805
- Context:
806
- {context_text}
807
-
808
- {context_memory}
809
-
810
- {past_context}
811
-
812
- Recent conversation:
813
- {history_text}
814
-
815
- User Question: {query}
816
-
817
- Answer based on the context above. Be specific and accurate. But don't always start with "based on the context"[/INST]"""
818
-
819
- return prompt
820
-
821
def ask(self, question: str) -> str:
    """Answer *question* using RAG over the PDF plus learned chat history.

    Side effects: updates pattern counts, conversation context, the
    chat-history FAISS index, and persists history/stats to disk per call.

    Args:
        question: The raw user question; "reset" / "reset data" clears state.

    Returns:
        The LLM's answer text (or a reset confirmation message).
    """
    # "reset" wipes in-memory conversation state and persists the empty log.
    if question.lower() in ["reset data", "reset"]:
        self.chat_history = []
        self.chat_embeddings = []
        self.chat_index = None
        self.conversation_context = {'current_employee': None, 'last_mentioned_entities': []}
        self._save_chat_history()
        return "Chat history has been reset."

    # Replace his/her/he/she with the employee currently under discussion.
    resolved_question = self._resolve_pronouns(question)

    # Track which query shapes users actually ask (learning statistic).
    pattern = self._extract_query_pattern(resolved_question)
    self.query_patterns[pattern] += 1

    # Recall semantically similar past exchanges for answer consistency.
    relevant_past_chats = self._search_chat_history(resolved_question, k=10)

    # Retrieve document chunks for the resolved question.
    retrieved_data = self._retrieve(resolved_question, k=20)

    prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)

    # Generate response
    messages = [{"role": "user", "content": prompt}]

    response = self.llm_client.chat_completion(
        messages=messages,
        model="meta-llama/Llama-3.1-8B-Instruct",
        max_tokens=512,
        temperature=0.3
    )

    answer = response.choices[0].message.content

    # Update pronoun-resolution context from the raw question/answer pair.
    self._update_conversation_context(question, answer)

    # Store in history with timestamp and metadata
    chat_entry = {
        'timestamp': datetime.now().isoformat(),
        'question': question,
        'answer': answer,
        'pattern': pattern,
        'used_past_context': len(relevant_past_chats) > 0
    }

    self.chat_history.append(chat_entry)

    # Incrementally extend the chat-history index with this new exchange.
    new_text = f"Q: {question}\nA: {answer}"
    new_embedding = self.embeddings_model.encode([new_text])

    if self.chat_index is None:
        # First entry: create the index sized to the embedding dimension.
        dimension = new_embedding.shape[1]
        self.chat_index = faiss.IndexFlatL2(dimension)
        self.chat_embeddings = new_embedding
    else:
        self.chat_embeddings = np.vstack([self.chat_embeddings, new_embedding])

    self.chat_index.add(np.array(new_embedding).astype('float32'))

    # Save to disk after each conversation
    self._save_chat_history()
    self._save_learning_stats()

    return answer
- def provide_feedback(self, question: str, rating: int):
893
- """Allow user to rate responses for reinforcement learning (1-5 scale)"""
894
- if 1 <= rating <= 5:
895
- # Find the most recent occurrence of this question
896
- for i in range(len(self.chat_history) - 1, -1, -1):
897
- if self.chat_history[i]['question'] == question:
898
- chat_id = f"{i}_{self.chat_history[i]['timestamp']}"
899
- self.feedback_scores[chat_id] = rating
900
- self._save_learning_stats()
901
- print(f"Feedback recorded: {rating}/5")
902
- return
903
- print("Question not found in recent history")
904
- else:
905
- print("Rating must be between 1 and 5")
906
-
907
- def get_learning_insights(self) -> Dict:
908
- """Get insights about what the chatbot has learned"""
909
- total_conversations = len(self.chat_history)
910
- conversations_with_past_context = sum(
911
- 1 for c in self.chat_history if c.get('used_past_context', False)
912
- )
913
-
914
- avg_feedback = 0
915
- if self.feedback_scores:
916
- avg_feedback = sum(self.feedback_scores.values()) / len(self.feedback_scores)
917
-
918
- return {
919
- 'total_conversations': total_conversations,
920
- 'conversations_using_past_context': conversations_with_past_context,
921
- 'query_patterns': dict(self.query_patterns.most_common(10)),
922
- 'total_feedback_entries': len(self.feedback_scores),
923
- 'average_feedback_score': round(avg_feedback, 2)
924
- }
925
-
926
- def get_history(self) -> List[Dict]:
927
- """Get chat history"""
928
- return self.chat_history
929
-
930
    def display_stats(self):
        """Print a console summary: chunk counts, learning stats, file paths."""
        # Tally indexed chunks by type from the metadata list.
        qa_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'qa')
        text_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'text')
        table_full = sum(1 for c in self.chunk_metadata if c['type'] == 'table_full')
        table_rows = sum(1 for c in self.chunk_metadata if c['type'] == 'table_row')

        insights = self.get_learning_insights()

        print(f"\n{'=' * 80}")
        print("CHATBOT STATISTICS")
        print(f"{'=' * 80}")
        print(f"Total chunks: {len(self.chunks)}")
        print(f" - Q&A chunks: {qa_chunks}")
        print(f" - Text chunks: {text_chunks}")
        print(f" - Full table: {table_full}")
        print(f" - Employee records: {table_rows}")
        print(f"\nLEARNING STATISTICS:")
        print(f" - Total conversations: {insights['total_conversations']}")
        print(f" - Conversations using past context: {insights['conversations_using_past_context']}")
        print(f" - Total feedback entries: {insights['total_feedback_entries']}")
        print(f" - Average feedback score: {insights['average_feedback_score']}/5")
        print(f"\nTop query patterns:")
        # insights['query_patterns'] is already limited to the top 10; show 5.
        for pattern, count in list(insights['query_patterns'].items())[:5]:
            print(f" - {pattern}: {count}")
        print(f"\nOutput directory: {self.output_dir}/")
        print(f"Table CSV: {os.path.basename(self.table_csv_path) if self.table_csv_path else 'None'}")
        print(f"Text chunks: {os.path.basename(self.text_chunks_path)}")
        print(f"History file: {os.path.basename(self.history_file)}")
        print(f"Learning stats: {os.path.basename(self.stats_file)}")
        print(f"{'=' * 80}\n")
961
-
962
-
963
# Main execution: interactive console REPL around the chatbot.
if __name__ == "__main__":
    # Configuration — PDF path is fixed; the HF token must come from the env.
    PDF_PATH = "data/policies.pdf"
    HF_TOKEN = os.getenv("HF_TOKEN")

    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable not set")

    # Initialize chatbot (loads PDF, builds indexes — may take a while).
    print("\nInitializing RAG Chatbot with Learning Capabilities...")
    bot = RAGChatbot(PDF_PATH, HF_TOKEN)

    # Display statistics
    bot.display_stats()

    # Chat loop: 'exit' quits, 'stats' prints learning insights,
    # 'feedback' rates the previous answer on a 1-5 scale.
    print("Chatbot ready! Type 'exit' to quit, 'stats' for learning insights, or 'feedback' to rate last answer.\n")
    last_question = None

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'q']:
            print("Goodbye!")
            break

        if user_input.lower() == 'stats':
            insights = bot.get_learning_insights()
            print("\nLearning Insights:")
            print(json.dumps(insights, indent=2))
            continue

        if user_input.lower() == 'feedback':
            if last_question:
                try:
                    rating = int(input("Rate the last answer (1-5): "))
                    bot.provide_feedback(last_question, rating)
                except ValueError:
                    print("Invalid rating")
            else:
                print("No previous question to rate")
            continue

        # Ignore empty input lines.
        if not user_input.strip():
            continue

        try:
            last_question = user_input
            answer = bot.ask(user_input)
            print(f"\nBot: {answer}\n")
        except Exception as e:
            # Keep the REPL alive on any failure (network, LLM, etc.).
            print(f"Error: {e}\n")
 
1
+ # RAG Chatbot with Separate Table and Text Processing + Reinforcement Learning from Chat History
2
+ import PyPDF2
3
+ import faiss
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from huggingface_hub import InferenceClient
7
+ from typing import List, Tuple, Dict
8
+ import json
9
+ import re
10
+ import pandas as pd
11
+ import tabula.io as tabula
12
+ import os
13
+ import pickle
14
+ from datetime import datetime
15
+ from collections import Counter
16
+
17
+
18
class RAGChatbot:
    """PDF-backed RAG chatbot that also learns from its own chat history."""

    def __init__(self, pdf_path: str, hf_token: str):
        """Load persisted state, ingest the PDF, and build all search indexes.

        Args:
            pdf_path: Path to the source PDF (policies + employee tables).
            hf_token: Hugging Face API token for the hosted LLM.
        """
        self.pdf_path = pdf_path
        self.hf_token = hf_token
        self.chunks = []              # raw text of every indexed chunk
        self.chunk_metadata = []      # parallel list of chunk metadata dicts
        self.index = None             # FAISS index over document chunks (built in _setup)
        self.embeddings_model = None  # SentenceTransformer, set in _setup()
        self.llm_client = None        # HF InferenceClient, set in _setup()
        self.chat_history = []
        self.output_dir = "./"
        self.table_csv_path = None
        self.text_chunks_path = None
        self.history_file = os.path.join(self.output_dir, "chat_history.pkl")

        # Chat history embeddings and FAISS index (semantic search over past Q&A)
        self.chat_embeddings = []
        self.chat_index = None
        self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")

        # Learning statistics: query-pattern counts and user feedback ratings
        self.query_patterns = Counter()
        self.feedback_scores = {}
        self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")

        # Lightweight conversational memory used for pronoun resolution
        self.conversation_context = {
            'current_employee': None,
            'last_mentioned_entities': []
        }

        os.makedirs(self.output_dir, exist_ok=True)

        # Load existing chat history and learning data before indexing
        self._load_chat_history()
        self._load_learning_stats()

        self._setup()

        # Build chat history index after setup (requires the embedding model)
        self._build_chat_history_index()
59
+
60
+ def _load_chat_history(self):
61
+ """Load chat history from file if exists"""
62
+ if os.path.exists(self.history_file):
63
+ try:
64
+ with open(self.history_file, 'rb') as f:
65
+ self.chat_history = pickle.load(f)
66
+ print(f"Loaded {len(self.chat_history)} previous conversations")
67
+ except Exception as e:
68
+ print(f"Could not load chat history: {e}")
69
+ self.chat_history = []
70
+ else:
71
+ self.chat_history = []
72
+
73
+ def _save_chat_history(self):
74
+ """Save chat history to file"""
75
+ try:
76
+ with open(self.history_file, 'wb') as f:
77
+ pickle.dump(self.chat_history, f)
78
+ except Exception as e:
79
+ print(f"Could not save chat history: {e}")
80
+
81
+ def _load_learning_stats(self):
82
+ """Load learning statistics"""
83
+ if os.path.exists(self.stats_file):
84
+ try:
85
+ with open(self.stats_file, 'rb') as f:
86
+ data = pickle.load(f)
87
+ self.query_patterns = data.get('query_patterns', Counter())
88
+ self.feedback_scores = data.get('feedback_scores', {})
89
+ print(f"Loaded learning statistics: {len(self.query_patterns)} patterns tracked")
90
+ except Exception as e:
91
+ print(f"Could not load learning stats: {e}")
92
+ self.query_patterns = Counter()
93
+ self.feedback_scores = {}
94
+ else:
95
+ self.query_patterns = Counter()
96
+ self.feedback_scores = {}
97
+
98
+ def _save_learning_stats(self):
99
+ """Save learning statistics"""
100
+ try:
101
+ with open(self.stats_file, 'wb') as f:
102
+ pickle.dump({
103
+ 'query_patterns': self.query_patterns,
104
+ 'feedback_scores': self.feedback_scores
105
+ }, f)
106
+ except Exception as e:
107
+ print(f"Could not save learning stats: {e}")
108
+
109
    def _build_chat_history_index(self):
        """Build a FAISS L2 index over past Q&A pairs for semantic recall."""
        if len(self.chat_history) == 0:
            print("No chat history to index")
            return

        print(f"Building semantic index for {len(self.chat_history)} past conversations...")

        # Embed each past exchange as one "Q: ...\nA: ..." document so a new
        # question can match either the old question or the old answer.
        chat_texts = []
        for entry in self.chat_history:
            combined_text = f"Q: {entry['question']}\nA: {entry['answer']}"
            chat_texts.append(combined_text)

        # Generate embeddings for every past exchange at once.
        self.chat_embeddings = self.embeddings_model.encode(chat_texts, show_progress_bar=True)

        # Build a flat (exact) L2 index — fine at chat-history scale.
        dimension = self.chat_embeddings.shape[1]
        self.chat_index = faiss.IndexFlatL2(dimension)
        self.chat_index.add(np.array(self.chat_embeddings).astype('float32'))

        # Persist embeddings (best-effort); NOTE(review): nothing reloads this
        # file yet — re-encoding happens on every startup.
        try:
            with open(self.chat_embedding_file, 'wb') as f:
                pickle.dump(self.chat_embeddings, f)
        except Exception as e:
            print(f"Could not save chat embeddings: {e}")

        print(f"Chat history index built successfully")
140
+
141
    def _search_chat_history(self, query: str, k: int = 5) -> List[Dict]:
        """Return up to *k* past conversations semantically similar to *query*.

        Each result is {'chat': <history entry>, 'similarity_score': <L2
        distance>}; lower distance means more similar. Returns [] when no
        index exists yet.
        """
        if self.chat_index is None or len(self.chat_history) == 0:
            return []

        # Encode query with the same model used to build the index.
        query_embedding = self.embeddings_model.encode([query])

        # Clamp k so FAISS never has to pad results.
        distances, indices = self.chat_index.search(
            np.array(query_embedding).astype('float32'),
            min(k, len(self.chat_history))
        )

        # Keep only reasonably close matches; 1.5 is an empirical L2 cutoff
        # for this embedding model — TODO confirm against real data.
        relevant_chats = []
        for idx, distance in zip(indices[0], distances[0]):
            if distance < 1.5:  # Similarity threshold
                relevant_chats.append({
                    'chat': self.chat_history[idx],
                    'similarity_score': float(distance)
                })

        return relevant_chats
165
+
166
+ def _extract_entities_from_query(self, query: str) -> Dict:
167
+ """Extract names and entities from query"""
168
+ query_lower = query.lower()
169
+
170
+ # Check for pronouns that need context
171
+ has_pronoun = bool(re.search(r'\b(his|her|their|he|she|they|him|them)\b', query_lower))
172
+
173
+ # Try to extract names (capitalize words that might be names)
174
+ potential_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
175
+
176
+ return {
177
+ 'has_pronoun': has_pronoun,
178
+ 'names': potential_names
179
+ }
180
+
181
+ def _update_conversation_context(self, question: str, answer: str):
182
+ """Update context tracking based on conversation"""
183
+ # Extract names from question
184
+ names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
185
+
186
+ # Extract names from answer
187
+ answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
188
+
189
+ # Update current employee if employee was mentioned
190
+ if 'employee' in answer.lower() or 'working' in answer.lower():
191
+ all_names = names + answer_names
192
+ if all_names:
193
+ self.conversation_context['current_employee'] = all_names[0]
194
+ # Keep last 5 mentioned entities
195
+ self.conversation_context['last_mentioned_entities'] = (
196
+ all_names[:5] if len(all_names) <= 5
197
+ else self.conversation_context['last_mentioned_entities'][-4:] + [all_names[0]]
198
+ )
199
+
200
+ def _resolve_pronouns(self, query: str) -> str:
201
+ """Replace pronouns with actual entity names from context"""
202
+ entities = self._extract_entities_from_query(query)
203
+
204
+ if entities['has_pronoun'] and self.conversation_context['current_employee']:
205
+ current_name = self.conversation_context['current_employee']
206
+
207
+ # Replace pronouns with the current employee name
208
+ query = re.sub(r'\bhis\b', f"{current_name}'s", query, flags=re.IGNORECASE)
209
+ query = re.sub(r'\bher\b', f"{current_name}'s", query, flags=re.IGNORECASE)
210
+ query = re.sub(r'\bhe\b', current_name, query, flags=re.IGNORECASE)
211
+ query = re.sub(r'\bshe\b', current_name, query, flags=re.IGNORECASE)
212
+
213
+ return query
214
+
215
+
216
+ def _extract_query_pattern(self, query: str) -> str:
217
+ """Extract pattern from query for learning"""
218
+ query_lower = query.lower()
219
+
220
+ # Detect common patterns
221
+ patterns = []
222
+
223
+ if re.search(r'\bhow many\b', query_lower):
224
+ patterns.append('count_query')
225
+ if re.search(r'\bwho\b', query_lower):
226
+ patterns.append('who_query')
227
+ if re.search(r'\bwhat\b', query_lower):
228
+ patterns.append('what_query')
229
+ if re.search(r'\bwhen\b', query_lower):
230
+ patterns.append('when_query')
231
+ if re.search(r'\bwhere\b', query_lower):
232
+ patterns.append('where_query')
233
+ if re.search(r'\blist\b|\ball\b', query_lower):
234
+ patterns.append('list_query')
235
+ if re.search(r'\bcalculate\b|\bsum\b|\btotal\b|\baverage\b', query_lower):
236
+ patterns.append('calculation_query')
237
+ if re.search(r'\bemployee\b|\bstaff\b|\bworker\b', query_lower):
238
+ patterns.append('employee_query')
239
+ if re.search(r'\bpolicy\b|\brule\b|\bguideline\b', query_lower):
240
+ patterns.append('policy_query')
241
+
242
+ return '|'.join(patterns) if patterns else 'general_query'
243
+
244
+ def _load_pdf_text(self) -> str:
245
+ """Load text from PDF"""
246
+ text = ""
247
+ with open(self.pdf_path, 'rb') as file:
248
+ pdf_reader = PyPDF2.PdfReader(file)
249
+ for page in pdf_reader.pages:
250
+ text += page.extract_text()
251
+ return text
252
+
253
    def _extract_and_merge_tables(self) -> str:
        """Extract every table from the PDF and merge them into one CSV.

        Returns:
            Path to the merged CSV, or None when no tables are found or
            extraction fails (the chatbot still works without table data).
        """
        try:
            print("Extracting tables from PDF...")

            # Extract all tables on all pages via tabula.
            dfs = tabula.read_pdf(self.pdf_path, pages="all", multiple_tables=True)

            if not dfs or len(dfs) == 0:
                print("No tables found in PDF")
                return None

            print(f"Found {len(dfs)} tables")

            # The first table carries the headers; later tables are assumed to
            # be continuations of the same table.
            merged_df = dfs[0]

            # Append rest of the tables.
            for i in range(1, len(dfs)):
                # Force continuation tables onto the first table's schema —
                # assumes identical column counts; TODO confirm for this PDF.
                dfs[i].columns = merged_df.columns
                merged_df = pd.concat([merged_df, dfs[i]], ignore_index=True)

            # Save merged table alongside the other artifacts.
            csv_path = os.path.join(self.output_dir, "merged_employee_tables.csv")
            merged_df.to_csv(csv_path, index=False)

            print(f"Merged {len(dfs)} tables into {csv_path}")
            print(f"Total rows: {len(merged_df)}")
            print(f"Columns: {list(merged_df.columns)}")

            return csv_path

        except Exception as e:
            # Best-effort: table extraction failures are non-fatal.
            print(f"Table extraction failed: {e}")
            return None
290
+
291
+ def _save_table_chunks(self, table_chunks: List[Dict]) -> str:
292
+ """Save table chunks (full table + row chunks) to a text file"""
293
+ save_path = os.path.join(self.output_dir, "table_chunks.txt")
294
+
295
+ with open(save_path, 'w', encoding='utf-8') as f:
296
+ f.write(f"Total Table Chunks: {len(table_chunks)}\n")
297
+ f.write("=" * 80 + "\n\n")
298
+
299
+ for i, chunk in enumerate(table_chunks):
300
+ f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
301
+ f.write("-" * 80 + "\n")
302
+ f.write(chunk['content'])
303
+ f.write("\n\n" + "=" * 80 + "\n\n")
304
+
305
+ print(f"Saved {len(table_chunks)} table chunks to {save_path}")
306
+ return save_path
307
+
308
+ def _detect_table_regions_in_text(self, text: str) -> List[Tuple[int, int]]:
309
+ """Detect start and end positions of table regions in text"""
310
+ lines = text.split('\n')
311
+ table_regions = []
312
+ start_idx = None
313
+
314
+ for i, line in enumerate(lines):
315
+ is_table_line = (
316
+ '@' in line or
317
+ re.search(r'\b(A|B|AB|O)[+-]\b', line) or
318
+ re.search(r'\s{3,}', line) or
319
+ re.search(r'Employee Name|Email|Position|Table|Blood Group', line, re.IGNORECASE)
320
+ )
321
+
322
+ if is_table_line:
323
+ if start_idx is None:
324
+ start_idx = i
325
+ else:
326
+ if start_idx is not None:
327
+ # End of table region
328
+ if i - start_idx > 3: # Only consider tables with 3+ lines
329
+ table_regions.append((start_idx, i))
330
+ start_idx = None
331
+
332
+ # Handle last table if exists
333
+ if start_idx is not None and len(lines) - start_idx > 3:
334
+ table_regions.append((start_idx, len(lines)))
335
+
336
+ return table_regions
337
+
338
+ def _remove_table_text(self, text: str) -> str:
339
+ """Remove table content from text"""
340
+ lines = text.split('\n')
341
+ table_regions = self._detect_table_regions_in_text(text)
342
+
343
+ # Create set of line indices to remove
344
+ lines_to_remove = set()
345
+ for start, end in table_regions:
346
+ for i in range(start, end):
347
+ lines_to_remove.add(i)
348
+
349
+ # Keep only non-table lines
350
+ clean_lines = [line for i, line in enumerate(lines) if i not in lines_to_remove]
351
+
352
+ return '\n'.join(clean_lines)
353
+
354
    def _chunk_text_content(self, text: str) -> List[Dict]:
        """Split non-table PDF text into Q&A chunks and free-text chunks.

        Q&A pairs are delimited by ###Question### / ###Answer### markers;
        remaining prose is chunked by blank-line-separated sections.
        """
        chunks = []

        # Drop table regions first so tabular rows don't pollute text chunks.
        clean_text = self._remove_table_text(text)

        # Split by ###Question### marker; each fragment is one Q&A candidate.
        qa_pairs = clean_text.split('###Question###')

        for i, qa in enumerate(qa_pairs):
            if not qa.strip():
                continue

            if '###Answer###' in qa:
                # Re-attach the marker that split() consumed.
                chunk_text = '###Question###' + qa
                if len(chunk_text) > 50:  # skip fragments too short to be useful
                    chunks.append({
                        'content': chunk_text,
                        'type': 'qa',
                        'source': 'text_content',
                        'chunk_id': f'qa_{i}'
                    })

        # Also create chunks from blank-line-separated sections (non-Q&A prose);
        # short sections (<200 chars) and anything containing a question marker
        # are skipped to avoid duplicating the Q&A chunks above.
        sections = re.split(r'\n\n+', clean_text)
        for i, section in enumerate(sections):
            section = section.strip()
            if len(section) > 200 and '###Question###' not in section:
                chunks.append({
                    'content': section,
                    'type': 'text',
                    'source': 'text_content',
                    'chunk_id': f'text_{i}'
                })

        return chunks
391
+
392
+ def _save_text_chunks(self, chunks: List[Dict]) -> str:
393
+ """Save text chunks to file"""
394
+ text_path = os.path.join(self.output_dir, "text_chunks.txt")
395
+
396
+ with open(text_path, 'w', encoding='utf-8') as f:
397
+ f.write(f"Total Text Chunks: {len(chunks)}\n")
398
+ f.write("=" * 80 + "\n\n")
399
+
400
+ for i, chunk in enumerate(chunks):
401
+ f.write(f"CHUNK {i + 1} [Type: {chunk['type']}]\n")
402
+ f.write("-" * 80 + "\n")
403
+ f.write(chunk['content'])
404
+ f.write("\n\n" + "=" * 80 + "\n\n")
405
+
406
+ print(f"Saved {len(chunks)} text chunks to {text_path}")
407
+ return text_path
408
+
409
+ def _load_csv_as_text(self, csv_path: str) -> str:
410
+ """Load CSV and convert to readable text format"""
411
+ try:
412
+ df = pd.read_csv(csv_path)
413
+ text = f"[EMPLOYEE TABLE DATA]\n"
414
+ text += f"Total Employees: {len(df)}\n\n"
415
+ text += df.to_string(index=False)
416
+ return text
417
+ except Exception as e:
418
+ print(f"Error loading CSV: {e}")
419
+ return ""
420
+
421
+ def _create_table_chunks(self, csv_path: str) -> List[Dict]:
422
+ """Create chunks from CSV table"""
423
+ chunks = []
424
+
425
+ try:
426
+ df = pd.read_csv(csv_path)
427
+
428
+ # Create one chunk with full table overview
429
+ full_table_text = f"[COMPLETE EMPLOYEE TABLE]\n"
430
+ full_table_text += f"Total Employees: {len(df)}\n"
431
+ full_table_text += f"Columns: {', '.join(df.columns)}\n\n"
432
+ full_table_text += df.to_string(index=False)
433
+
434
+ chunks.append({
435
+ 'content': full_table_text,
436
+ 'type': 'table_full',
437
+ 'source': 'employee_table.csv',
438
+ 'chunk_id': 'table_full'
439
+ })
440
+
441
+ # Create chunks for each row (employee)
442
+ for idx, row in df.iterrows():
443
+ row_text = f"[EMPLOYEE RECORD {idx + 1}]\n"
444
+ for col in df.columns:
445
+ row_text += f"{col}: {row[col]}\n"
446
+
447
+ chunks.append({
448
+ 'content': row_text,
449
+ 'type': 'table_row',
450
+ 'source': 'employee_table.csv',
451
+ 'chunk_id': f'employee_{idx}'
452
+ })
453
+
454
+ print(f"Created {len(chunks)} chunks from table ({len(df)} employee records + 1 full table)")
455
+
456
+ except Exception as e:
457
+ print(f"Error creating table chunks: {e}")
458
+
459
+ return chunks
460
+
461
+ def _save_manifest(self, all_chunks: List[Dict]):
462
+ """Save manifest of all chunks"""
463
+ manifest = {
464
+ 'total_chunks': len(all_chunks),
465
+ 'chunks_by_type': {
466
+ 'qa': sum(1 for c in all_chunks if c['type'] == 'qa'),
467
+ 'text': sum(1 for c in all_chunks if c['type'] == 'text'),
468
+ 'table_full': sum(1 for c in all_chunks if c['type'] == 'table_full'),
469
+ 'table_row': sum(1 for c in all_chunks if c['type'] == 'table_row')
470
+ },
471
+ 'files_created': {
472
+ 'table_csv': self.table_csv_path,
473
+ 'text_chunks': self.text_chunks_path
474
+ },
475
+ 'chunk_details': [
476
+ {
477
+ 'chunk_id': c['chunk_id'],
478
+ 'type': c['type'],
479
+ 'source': c['source'],
480
+ 'length': len(c['content'])
481
+ }
482
+ for c in all_chunks
483
+ ]
484
+ }
485
+
486
+ manifest_path = os.path.join(self.output_dir, 'chunk_manifest.json')
487
+ with open(manifest_path, 'w', encoding='utf-8') as f:
488
+ json.dump(manifest, f, indent=2, ensure_ascii=False)
489
+
490
+ print(f"Saved manifest to {manifest_path}")
491
+ return manifest_path
492
+
493
+ def _resolve_pronouns_for_session(self, query: str, conversation_context: Dict) -> str:
494
+ """Resolve pronouns using session-specific context"""
495
+ entities = self._extract_entities_from_query(query)
496
+
497
+ if entities['has_pronoun'] and conversation_context.get('current_employee'):
498
+ current_name = conversation_context['current_employee']
499
+
500
+ query = re.sub(r'\bhis\b', f"{current_name}'s", query, flags=re.IGNORECASE)
501
+ query = re.sub(r'\bher\b', f"{current_name}'s", query, flags=re.IGNORECASE)
502
+ query = re.sub(r'\bhe\b', current_name, query, flags=re.IGNORECASE)
503
+ query = re.sub(r'\bshe\b', current_name, query, flags=re.IGNORECASE)
504
+
505
+ return query
506
+
507
    def _search_session_history(self, query: str, session_history: List[Dict], k: int = 5) -> List[Dict]:
        """Semantically search a caller-supplied (per-session) history list.

        Builds a throwaway FAISS index on every call — acceptable for short
        sessions, but O(len(history)) re-encoding per query otherwise.
        """
        if not session_history:
            return []

        # One "Q: ...\nA: ..." document per past exchange.
        chat_texts = [f"Q: {entry['question']}\nA: {entry['answer']}" for entry in session_history]

        if not chat_texts:
            return []

        chat_embeddings = self.embeddings_model.encode(chat_texts)

        # Temporary exact L2 index over just this session's exchanges.
        dimension = chat_embeddings.shape[1]
        temp_index = faiss.IndexFlatL2(dimension)
        temp_index.add(np.array(chat_embeddings).astype('float32'))

        query_embedding = self.embeddings_model.encode([query])
        distances, indices = temp_index.search(
            np.array(query_embedding).astype('float32'),
            min(k, len(session_history))
        )

        # Same empirical 1.5 L2 cutoff as _search_chat_history.
        relevant_chats = []
        for idx, distance in zip(indices[0], distances[0]):
            if distance < 1.5:
                relevant_chats.append({
                    'chat': session_history[idx],
                    'similarity_score': float(distance)
                })

        return relevant_chats
538
+
539
    def _build_prompt_for_session(self, query: str, retrieved_data: List[Tuple[str, Dict]],
                                  relevant_past_chats: List[Dict], session_history: List[Dict],
                                  conversation_context: Dict) -> str:
        """Assemble the LLM prompt from retrieved chunks plus session memory.

        Mirrors _build_prompt but takes history/context as parameters instead
        of reading instance state, so it is safe for multi-user sessions.
        """
        # Bucket retrieved chunks by type so the prompt can order them.
        employee_records = []
        full_table = []
        qa_context = []
        text_context = []

        for content, metadata in retrieved_data:
            if metadata['type'] == 'table_row':
                employee_records.append(content)
            elif metadata['type'] == 'table_full':
                full_table.append(content)
            elif metadata['type'] == 'qa':
                qa_context.append(content)
            elif metadata['type'] == 'text':
                text_context.append(content)

        context_text = ""
        if full_table:
            context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
        if employee_records:
            # Cap at 15 records to keep the prompt within the model's budget.
            context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
        if qa_context:
            context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
        if text_context:
            context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)

        # Conversational memory: who this session is currently talking about.
        context_memory = ""
        if conversation_context.get('current_employee'):
            context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
            context_memory += f"Currently discussing: {conversation_context['current_employee']}\n"
            if conversation_context.get('last_mentioned_entities'):
                context_memory += f"Recently mentioned: {', '.join(conversation_context['last_mentioned_entities'])}\n"
            context_memory += "\n"

        # Up to 3 semantically similar past exchanges for grounding.
        past_context = ""
        if relevant_past_chats:
            past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
            for i, chat_info in enumerate(relevant_past_chats[:3], 1):
                chat = chat_info['chat']
                past_context += f"\n[Past Q&A {i}]:\n"
                past_context += f"Previous Question: {chat['question']}\n"
                past_context += f"Previous Answer: {chat['answer']}\n"
                past_context += "\n"

        # Last 10 turns verbatim for short-term coherence.
        history_text = ""
        for entry in session_history[-10:]:
            history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"

        # Llama-2/3 instruction format ([INST] ... [/INST]).
        prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.

IMPORTANT INSTRUCTIONS:
- You have access to COMPLETE EMPLOYEE TABLE and individual employee records
- For employee-related queries, use the employee data provided
- If you find any name from user input, always look into the EMPLOYEE TABLE first
- PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in THIS USER's recent conversation
- When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
- Be careful not to give all employee information - only answer what was asked
- For counting or calculations, use the table data
- For policy questions, use the Q&A knowledge base
- Provide specific, accurate answers based on the context
- If information is not in the context, say "I don't have this information"
- Round up any fractional numbers in calculations

Context:
{context_text}

{context_memory}

{past_context}

Recent conversation:
{history_text}

User Question: {query}

Answer based on the context above. Be specific and accurate.[/INST]"""

        return prompt
621
+
622
+ def _update_conversation_context_for_session(self, question: str, answer: str, conversation_context: Dict):
623
+ """Update session-specific conversation context"""
624
+ names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', question)
625
+ answer_names = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
626
+
627
+ if 'employee' in answer.lower() or 'working' in answer.lower():
628
+ all_names = names + answer_names
629
+ if all_names:
630
+ conversation_context['current_employee'] = all_names[0]
631
+ conversation_context['last_mentioned_entities'] = (
632
+ all_names[:5] if len(all_names) <= 5
633
+ else conversation_context.get('last_mentioned_entities', [])[-4:] + [all_names[0]]
634
+ )
635
+
636
    def _setup(self):
        """One-time pipeline: load PDF, extract tables, chunk, embed, index, init LLM."""
        print("\n" + "=" * 80)
        print("STEP 1: Loading PDF")
        print("=" * 80)

        text = self._load_pdf_text()
        print(f"Loaded PDF with {len(text)} characters")

        print("\n" + "=" * 80)
        print("STEP 2: Extracting and Merging Tables")
        print("=" * 80)

        # May be None when the PDF has no extractable tables.
        self.table_csv_path = self._extract_and_merge_tables()

        print("\n" + "=" * 80)
        print("STEP 3: Chunking Text Content (Removing Tables)")
        print("=" * 80)

        text_chunks = self._chunk_text_content(text)
        self.text_chunks_path = self._save_text_chunks(text_chunks)

        print("\n" + "=" * 80)
        print("STEP 4: Creating Final Chunks")
        print("=" * 80)

        all_chunks = []

        # Add text chunks
        all_chunks.extend(text_chunks)

        # Add table chunks (skipped when no tables were extracted)
        if self.table_csv_path:
            table_chunks = self._create_table_chunks(self.table_csv_path)
            all_chunks.extend(table_chunks)
            # Save chunked table text to file for inspection
            self._save_table_chunks(table_chunks)

        # Parallel lists: chunk text for embedding, metadata for retrieval.
        self.chunks = [c['content'] for c in all_chunks]
        self.chunk_metadata = all_chunks

        print(f"\nTotal chunks created: {len(self.chunks)}")
        print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
        print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
        print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
        print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")

        # Save manifest
        self._save_manifest(all_chunks)

        print("\n" + "=" * 80)
        print("STEP 5: Creating Embeddings")
        print("=" * 80)

        print("Loading embedding model...")
        self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        print("Creating embeddings for all chunks...")
        embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)

        # Flat (exact) L2 index — fine at this corpus size.
        print("Building FAISS index...")
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(np.array(embeddings).astype('float32'))

        print("\n" + "=" * 80)
        print("STEP 6: Initializing LLM")
        print("=" * 80)

        self.llm_client = InferenceClient(token=self.hf_token)

        print("\n" + "=" * 80)
        print("SETUP COMPLETE!")
        print("=" * 80)
        print(f"Files created in: {self.output_dir}/")
        print(f" - {os.path.basename(self.table_csv_path) if self.table_csv_path else 'No table CSV'}")
        print(f" - {os.path.basename(self.text_chunks_path)}")
        print(f" - chunk_manifest.json")
        print(f" - {os.path.basename(self.history_file)}")
        print("=" * 80 + "\n")
716
+
717
+ def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
718
+ """Retrieve relevant chunks with metadata"""
719
+ query_embedding = self.embeddings_model.encode([query])
720
+ distances, indices = self.index.search(np.array(query_embedding).astype('float32'), k)
721
+
722
+ results = []
723
+ for idx in indices[0]:
724
+ results.append((self.chunks[idx], self.chunk_metadata[idx]))
725
+
726
+ return results
727
+
728
+ def _build_prompt(self, query: str, retrieved_data: List[Tuple[str, Dict]], relevant_past_chats: List[Dict]) -> str:
729
+ """Build prompt with retrieved context and learned information from past chats"""
730
+
731
+ # Separate different types of context
732
+ employee_records = []
733
+ full_table = []
734
+ qa_context = []
735
+ text_context = []
736
+
737
+ for content, metadata in retrieved_data:
738
+ if metadata['type'] == 'table_row':
739
+ employee_records.append(content)
740
+ elif metadata['type'] == 'table_full':
741
+ full_table.append(content)
742
+ elif metadata['type'] == 'qa':
743
+ qa_context.append(content)
744
+ elif metadata['type'] == 'text':
745
+ text_context.append(content)
746
+
747
+ # Build context sections
748
+ context_text = ""
749
+
750
+ if full_table:
751
+ context_text += "COMPLETE EMPLOYEE TABLE:\n" + "\n".join(full_table) + "\n\n"
752
+
753
+ if employee_records:
754
+ context_text += "RELEVANT EMPLOYEE RECORDS:\n" + "\n\n".join(employee_records[:15]) + "\n\n"
755
+
756
+ if qa_context:
757
+ context_text += "COMPANY POLICIES & Q&A:\n" + "\n\n".join(qa_context) + "\n\n"
758
+
759
+ if text_context:
760
+ context_text += "ADDITIONAL INFORMATION:\n" + "\n\n".join(text_context)
761
+
762
+ # ADD THIS NEW SECTION:
763
+ context_memory = ""
764
+ if self.conversation_context['current_employee']:
765
+ context_memory = f"\nCURRENT CONVERSATION CONTEXT:\n"
766
+ context_memory += f"Currently discussing: {self.conversation_context['current_employee']}\n"
767
+ if self.conversation_context['last_mentioned_entities']:
768
+ context_memory += f"Recently mentioned: {', '.join(self.conversation_context['last_mentioned_entities'])}\n"
769
+ context_memory += "\n"
770
+
771
+ # Build relevant past conversations (learning from history)
772
+ past_context = ""
773
+ if relevant_past_chats:
774
+ past_context += "RELEVANT PAST CONVERSATIONS (for context):\n"
775
+ for i, chat_info in enumerate(relevant_past_chats[:3], 1):
776
+ chat = chat_info['chat']
777
+ past_context += f"\n[Past Q&A {i}]:\n"
778
+ past_context += f"Previous Question: {chat['question']}\n"
779
+ past_context += f"Previous Answer: {chat['answer']}\n"
780
+ past_context += "\n"
781
+
782
+ # CHANGE THIS LINE from [-3:] to [-10:]:
783
+ history_text = ""
784
+ for entry in self.chat_history: # Changed from -3 to -10
785
+ history_text += f"User: {entry['question']}\nAssistant: {entry['answer']}\n\n"
786
+
787
+ prompt = f"""<s>[INST] You are a helpful HR assistant for Acme AI Ltd. Use the provided context to answer questions accurately.
788
+
789
+ IMPORTANT INSTRUCTIONS:
790
+ - You have access to COMPLETE EMPLOYEE TABLE and individual employee records
791
+ - For employee-related queries, use the employee data provided
792
+ - If you find any name from user input, always look into the EMPLOYEE TABLE first. If you still can't find, then you can go for chunked text.
793
+ - PAY ATTENTION to pronouns (his, her, he, she) - they refer to people mentioned in recent conversation
794
+ - When user asks about "his email" or "her position", look at the conversation context to understand who they're referring to
795
+ - While your answer is related to an employee, be careful of not giving all the information of the employee. Just give the information user asked.
796
+ - For counting or calculations, use the table data
797
+ - For policy questions, use the Q&A knowledge base
798
+ - LEARN from relevant past conversations - if similar questions were asked before, maintain consistency
799
+ - Use patterns from past interactions to improve answer quality
800
+ - Provide specific, accurate answers based on the context
801
+ - If you need to count employees or perform calculations, do it carefully from the data
802
+ - If information is not in the context, just say "I don't have this information in the provided documents"
803
+ - While performing any type of mathematical calculation, always round up any fractional number.
804
+
805
+ Context:
806
+ {context_text}
807
+
808
+ {context_memory}
809
+
810
+ {past_context}
811
+
812
+ Recent conversation:
813
+ {history_text}
814
+
815
+ User Question: {query}
816
+
817
+ Answer based on the context above. Be specific and accurate. But don't always start with "based on the context"[/INST]"""
818
+
819
+ return prompt
820
+
821
    def ask(self, question: str) -> str:
        """Answer a user question using RAG over the PDF plus learned chat history.

        Pipeline: pronoun resolution -> query-pattern tracking -> retrieval of
        similar past chats and document chunks -> LLM call -> context/history
        bookkeeping and persistence. The special inputs "reset"/"reset data"
        clear all conversation state instead of querying the model.

        Args:
            question: The raw user question.

        Returns:
            The model's answer string (or a reset confirmation message).
        """
        if question.lower() in ["reset data", "reset"]:
            # Wipe in-memory conversation state and persist the empty history.
            self.chat_history = []
            self.chat_embeddings = []
            self.chat_index = None
            self.conversation_context = {'current_employee': None, 'last_mentioned_entities': []}
            self._save_chat_history()
            return "Chat history has been reset."

        # Rewrite pronouns (his/her/he/she) using the tracked conversation
        # context so retrieval searches for the actual employee name.
        resolved_question = self._resolve_pronouns(question)

        # Record the query pattern for the learning-insights statistics.
        pattern = self._extract_query_pattern(resolved_question)
        self.query_patterns[pattern] += 1

        # Find similar past conversations so answers stay consistent.
        relevant_past_chats = self._search_chat_history(resolved_question, k=10)

        # Retrieve document chunks relevant to the resolved question.
        retrieved_data = self._retrieve(resolved_question, k=20)

        prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)

        # Generate response
        messages = [{"role": "user", "content": prompt}]

        response = self.llm_client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=messages,
            max_tokens=512,
            temperature=0.3
        )

        answer = response.choices[0].message.content

        # Update who/what the conversation is currently about (feeds future
        # pronoun resolution). Deliberately uses the original question, not
        # the resolved one.
        self._update_conversation_context(question, answer)

        # Store in history with timestamp and metadata
        chat_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': question,
            'answer': answer,
            'pattern': pattern,
            'used_past_context': len(relevant_past_chats) > 0
        }

        self.chat_history.append(chat_entry)

        # Embed the new Q&A pair and add it to the chat-history FAISS index
        # so later questions can retrieve this conversation.
        new_text = f"Q: {question}\nA: {answer}"
        new_embedding = self.embeddings_model.encode([new_text])

        if self.chat_index is None:
            # First conversation: create the index lazily with the right
            # embedding dimension.
            dimension = new_embedding.shape[1]
            self.chat_index = faiss.IndexFlatL2(dimension)
            self.chat_embeddings = new_embedding
        else:
            self.chat_embeddings = np.vstack([self.chat_embeddings, new_embedding])

        self.chat_index.add(np.array(new_embedding).astype('float32'))

        # Save to disk after each conversation
        self._save_chat_history()
        self._save_learning_stats()

        return answer
891
+
892
+ def provide_feedback(self, question: str, rating: int):
893
+ """Allow user to rate responses for reinforcement learning (1-5 scale)"""
894
+ if 1 <= rating <= 5:
895
+ # Find the most recent occurrence of this question
896
+ for i in range(len(self.chat_history) - 1, -1, -1):
897
+ if self.chat_history[i]['question'] == question:
898
+ chat_id = f"{i}_{self.chat_history[i]['timestamp']}"
899
+ self.feedback_scores[chat_id] = rating
900
+ self._save_learning_stats()
901
+ print(f"Feedback recorded: {rating}/5")
902
+ return
903
+ print("Question not found in recent history")
904
+ else:
905
+ print("Rating must be between 1 and 5")
906
+
907
+ def get_learning_insights(self) -> Dict:
908
+ """Get insights about what the chatbot has learned"""
909
+ total_conversations = len(self.chat_history)
910
+ conversations_with_past_context = sum(
911
+ 1 for c in self.chat_history if c.get('used_past_context', False)
912
+ )
913
+
914
+ avg_feedback = 0
915
+ if self.feedback_scores:
916
+ avg_feedback = sum(self.feedback_scores.values()) / len(self.feedback_scores)
917
+
918
+ return {
919
+ 'total_conversations': total_conversations,
920
+ 'conversations_using_past_context': conversations_with_past_context,
921
+ 'query_patterns': dict(self.query_patterns.most_common(10)),
922
+ 'total_feedback_entries': len(self.feedback_scores),
923
+ 'average_feedback_score': round(avg_feedback, 2)
924
+ }
925
+
926
    def get_history(self) -> List[Dict]:
        """Return the full chat history.

        Note: this returns the live internal list (not a copy), so callers
        that mutate it will mutate the chatbot's stored history.
        """
        return self.chat_history
929
+
930
+ def display_stats(self):
931
+ """Display system statistics"""
932
+ qa_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'qa')
933
+ text_chunks = sum(1 for c in self.chunk_metadata if c['type'] == 'text')
934
+ table_full = sum(1 for c in self.chunk_metadata if c['type'] == 'table_full')
935
+ table_rows = sum(1 for c in self.chunk_metadata if c['type'] == 'table_row')
936
+
937
+ insights = self.get_learning_insights()
938
+
939
+ print(f"\n{'=' * 80}")
940
+ print("CHATBOT STATISTICS")
941
+ print(f"{'=' * 80}")
942
+ print(f"Total chunks: {len(self.chunks)}")
943
+ print(f" - Q&A chunks: {qa_chunks}")
944
+ print(f" - Text chunks: {text_chunks}")
945
+ print(f" - Full table: {table_full}")
946
+ print(f" - Employee records: {table_rows}")
947
+ print(f"\nLEARNING STATISTICS:")
948
+ print(f" - Total conversations: {insights['total_conversations']}")
949
+ print(f" - Conversations using past context: {insights['conversations_using_past_context']}")
950
+ print(f" - Total feedback entries: {insights['total_feedback_entries']}")
951
+ print(f" - Average feedback score: {insights['average_feedback_score']}/5")
952
+ print(f"\nTop query patterns:")
953
+ for pattern, count in list(insights['query_patterns'].items())[:5]:
954
+ print(f" - {pattern}: {count}")
955
+ print(f"\nOutput directory: {self.output_dir}/")
956
+ print(f"Table CSV: {os.path.basename(self.table_csv_path) if self.table_csv_path else 'None'}")
957
+ print(f"Text chunks: {os.path.basename(self.text_chunks_path)}")
958
+ print(f"History file: {os.path.basename(self.history_file)}")
959
+ print(f"Learning stats: {os.path.basename(self.stats_file)}")
960
+ print(f"{'=' * 80}\n")
961
+
962
+
963
# Main execution
if __name__ == "__main__":
    # Configuration
    PDF_PATH = "data/policies.pdf"
    HF_TOKEN = os.getenv("HF_TOKEN")

    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable not set")

    # Build the chatbot (parses the PDF, embeds chunks, loads saved history).
    print("\nInitializing RAG Chatbot with Learning Capabilities...")
    bot = RAGChatbot(PDF_PATH, HF_TOKEN)
    bot.display_stats()

    # Interactive REPL: a few keywords are handled locally, everything else
    # is sent to the bot as a question.
    print("Chatbot ready! Type 'exit' to quit, 'stats' for learning insights, or 'feedback' to rate last answer.\n")
    last_question = None

    while True:
        user_input = input("You: ")
        command = user_input.lower()

        if command in ('exit', 'quit', 'q'):
            print("Goodbye!")
            break

        if command == 'stats':
            print("\nLearning Insights:")
            print(json.dumps(bot.get_learning_insights(), indent=2))
            continue

        if command == 'feedback':
            if last_question is None:
                print("No previous question to rate")
            else:
                try:
                    rating = int(input("Rate the last answer (1-5): "))
                    bot.provide_feedback(last_question, rating)
                except ValueError:
                    print("Invalid rating")
            continue

        # Ignore blank lines.
        if not user_input.strip():
            continue

        try:
            last_question = user_input
            answer = bot.ask(user_input)
            print(f"\nBot: {answer}\n")
        except Exception as e:
            print(f"Error: {e}\n")