Spaces:
Sleeping
Sleeping
Update chatbot.py
Browse files- chatbot.py +68 -60
chatbot.py
CHANGED
|
@@ -13,6 +13,7 @@ import os
|
|
| 13 |
import pickle
|
| 14 |
from datetime import datetime
|
| 15 |
from collections import Counter
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class RAGChatbot:
|
|
@@ -23,38 +24,38 @@ class RAGChatbot:
|
|
| 23 |
self.chunk_metadata = []
|
| 24 |
self.index = None
|
| 25 |
self.embeddings_model = None
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
self.chat_history = []
|
| 28 |
self.output_dir = "./"
|
| 29 |
self.table_csv_path = None
|
| 30 |
self.text_chunks_path = None
|
| 31 |
self.history_file = os.path.join(self.output_dir, "chat_history.pkl")
|
| 32 |
-
|
| 33 |
-
# Chat history embeddings and index
|
| 34 |
self.chat_embeddings = []
|
| 35 |
self.chat_index = None
|
| 36 |
self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")
|
| 37 |
-
|
| 38 |
-
# Learning statistics
|
| 39 |
self.query_patterns = Counter()
|
| 40 |
self.feedback_scores = {}
|
| 41 |
self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")
|
| 42 |
-
|
| 43 |
-
# ADD THIS NEW SECTION:
|
| 44 |
self.conversation_context = {
|
| 45 |
'current_employee': None,
|
| 46 |
'last_mentioned_entities': []
|
| 47 |
}
|
| 48 |
-
|
| 49 |
os.makedirs(self.output_dir, exist_ok=True)
|
| 50 |
-
|
| 51 |
-
# Load existing chat history and learning data
|
| 52 |
self._load_chat_history()
|
| 53 |
self._load_learning_stats()
|
| 54 |
-
|
| 55 |
self._setup()
|
| 56 |
-
|
| 57 |
-
# Build chat history index after setup
|
| 58 |
self._build_chat_history_index()
|
| 59 |
|
| 60 |
def _load_chat_history(self):
|
|
@@ -637,83 +638,73 @@ class RAGChatbot:
|
|
| 637 |
print("\n" + "=" * 80)
|
| 638 |
print("STEP 1: Loading PDF")
|
| 639 |
print("=" * 80)
|
| 640 |
-
|
| 641 |
text = self._load_pdf_text()
|
| 642 |
print(f"Loaded PDF with {len(text)} characters")
|
| 643 |
-
|
| 644 |
print("\n" + "=" * 80)
|
| 645 |
print("STEP 2: Extracting and Merging Tables")
|
| 646 |
print("=" * 80)
|
| 647 |
-
|
| 648 |
self.table_csv_path = self._extract_and_merge_tables()
|
| 649 |
-
|
| 650 |
print("\n" + "=" * 80)
|
| 651 |
print("STEP 3: Chunking Text Content (Removing Tables)")
|
| 652 |
print("=" * 80)
|
| 653 |
-
|
| 654 |
text_chunks = self._chunk_text_content(text)
|
| 655 |
self.text_chunks_path = self._save_text_chunks(text_chunks)
|
| 656 |
-
|
| 657 |
print("\n" + "=" * 80)
|
| 658 |
print("STEP 4: Creating Final Chunks")
|
| 659 |
print("=" * 80)
|
| 660 |
-
|
| 661 |
all_chunks = []
|
| 662 |
-
|
| 663 |
-
# Add text chunks
|
| 664 |
all_chunks.extend(text_chunks)
|
| 665 |
-
|
| 666 |
-
# Add table chunks
|
| 667 |
if self.table_csv_path:
|
| 668 |
table_chunks = self._create_table_chunks(self.table_csv_path)
|
| 669 |
all_chunks.extend(table_chunks)
|
| 670 |
-
# Save chunked table text to file
|
| 671 |
self._save_table_chunks(table_chunks)
|
| 672 |
-
|
| 673 |
-
# Extract content and metadata
|
| 674 |
self.chunks = [c['content'] for c in all_chunks]
|
| 675 |
self.chunk_metadata = all_chunks
|
| 676 |
-
|
| 677 |
print(f"\nTotal chunks created: {len(self.chunks)}")
|
| 678 |
print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
|
| 679 |
print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
|
| 680 |
print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
|
| 681 |
print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")
|
| 682 |
-
|
| 683 |
-
# Save manifest
|
| 684 |
self._save_manifest(all_chunks)
|
| 685 |
-
|
| 686 |
print("\n" + "=" * 80)
|
| 687 |
print("STEP 5: Creating Embeddings")
|
| 688 |
print("=" * 80)
|
| 689 |
-
|
| 690 |
print("Loading embedding model...")
|
| 691 |
self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 692 |
-
|
| 693 |
print("Creating embeddings for all chunks...")
|
| 694 |
embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)
|
| 695 |
-
|
| 696 |
print("Building FAISS index...")
|
| 697 |
dimension = embeddings.shape[1]
|
| 698 |
self.index = faiss.IndexFlatL2(dimension)
|
| 699 |
self.index.add(np.array(embeddings).astype('float32'))
|
| 700 |
-
|
| 701 |
print("\n" + "=" * 80)
|
| 702 |
-
print("STEP 6: Initializing LLM")
|
| 703 |
print("=" * 80)
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
|
|
|
| 708 |
print("\n" + "=" * 80)
|
| 709 |
print("SETUP COMPLETE!")
|
| 710 |
print("=" * 80)
|
| 711 |
-
print(f"Files created in: {self.output_dir}/")
|
| 712 |
-
print(f" - {os.path.basename(self.table_csv_path) if self.table_csv_path else 'No table CSV'}")
|
| 713 |
-
print(f" - {os.path.basename(self.text_chunks_path)}")
|
| 714 |
-
print(f" - chunk_manifest.json")
|
| 715 |
-
print(f" - {os.path.basename(self.history_file)}")
|
| 716 |
-
print("=" * 80 + "\n")
|
| 717 |
|
| 718 |
def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
|
| 719 |
"""Retrieve relevant chunks with metadata"""
|
|
@@ -820,6 +811,7 @@ class RAGChatbot:
|
|
| 820 |
return prompt
|
| 821 |
|
| 822 |
def ask(self, question: str) -> str:
|
|
|
|
| 823 |
if question.lower() in ["reset data", "reset"]:
|
| 824 |
self.chat_history = []
|
| 825 |
self.chat_embeddings = []
|
|
@@ -838,25 +830,41 @@ class RAGChatbot:
|
|
| 838 |
# Search through past conversations for similar questions
|
| 839 |
relevant_past_chats = self._search_chat_history(resolved_question, k=5)
|
| 840 |
|
| 841 |
-
# Retrieve relevant chunks
|
| 842 |
retrieved_data = self._retrieve(resolved_question, k=20)
|
| 843 |
|
| 844 |
-
# Build prompt
|
| 845 |
prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)
|
| 846 |
|
| 847 |
-
# ✅
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
|
| 856 |
-
# Update conversation context
|
| 857 |
self._update_conversation_context(question, answer)
|
| 858 |
|
| 859 |
-
# Store in history
|
| 860 |
chat_entry = {
|
| 861 |
'timestamp': datetime.now().isoformat(),
|
| 862 |
'question': question,
|
|
@@ -867,7 +875,7 @@ class RAGChatbot:
|
|
| 867 |
|
| 868 |
self.chat_history.append(chat_entry)
|
| 869 |
|
| 870 |
-
# Update chat history index
|
| 871 |
new_text = f"Q: {question}\nA: {answer}"
|
| 872 |
new_embedding = self.embeddings_model.encode([new_text])
|
| 873 |
|
|
@@ -880,7 +888,7 @@ class RAGChatbot:
|
|
| 880 |
|
| 881 |
self.chat_index.add(np.array(new_embedding).astype('float32'))
|
| 882 |
|
| 883 |
-
# Save to disk
|
| 884 |
self._save_chat_history()
|
| 885 |
self._save_learning_stats()
|
| 886 |
|
|
|
|
| 13 |
import pickle
|
| 14 |
from datetime import datetime
|
| 15 |
from collections import Counter
|
| 16 |
+
import requests
|
| 17 |
|
| 18 |
|
| 19 |
class RAGChatbot:
|
|
|
|
| 24 |
self.chunk_metadata = []
|
| 25 |
self.index = None
|
| 26 |
self.embeddings_model = None
|
| 27 |
+
|
| 28 |
+
# ✅ NEW: API configuration
|
| 29 |
+
self.api_url = "https://router.huggingface.co/v1/chat/completions"
|
| 30 |
+
self.headers = {"Authorization": f"Bearer {hf_token}"}
|
| 31 |
+
self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 32 |
+
|
| 33 |
self.chat_history = []
|
| 34 |
self.output_dir = "./"
|
| 35 |
self.table_csv_path = None
|
| 36 |
self.text_chunks_path = None
|
| 37 |
self.history_file = os.path.join(self.output_dir, "chat_history.pkl")
|
| 38 |
+
|
|
|
|
| 39 |
self.chat_embeddings = []
|
| 40 |
self.chat_index = None
|
| 41 |
self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")
|
| 42 |
+
|
|
|
|
| 43 |
self.query_patterns = Counter()
|
| 44 |
self.feedback_scores = {}
|
| 45 |
self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")
|
| 46 |
+
|
|
|
|
| 47 |
self.conversation_context = {
|
| 48 |
'current_employee': None,
|
| 49 |
'last_mentioned_entities': []
|
| 50 |
}
|
| 51 |
+
|
| 52 |
os.makedirs(self.output_dir, exist_ok=True)
|
| 53 |
+
|
|
|
|
| 54 |
self._load_chat_history()
|
| 55 |
self._load_learning_stats()
|
| 56 |
+
|
| 57 |
self._setup()
|
| 58 |
+
|
|
|
|
| 59 |
self._build_chat_history_index()
|
| 60 |
|
| 61 |
def _load_chat_history(self):
|
|
|
|
| 638 |
print("\n" + "=" * 80)
|
| 639 |
print("STEP 1: Loading PDF")
|
| 640 |
print("=" * 80)
|
| 641 |
+
|
| 642 |
text = self._load_pdf_text()
|
| 643 |
print(f"Loaded PDF with {len(text)} characters")
|
| 644 |
+
|
| 645 |
print("\n" + "=" * 80)
|
| 646 |
print("STEP 2: Extracting and Merging Tables")
|
| 647 |
print("=" * 80)
|
| 648 |
+
|
| 649 |
self.table_csv_path = self._extract_and_merge_tables()
|
| 650 |
+
|
| 651 |
print("\n" + "=" * 80)
|
| 652 |
print("STEP 3: Chunking Text Content (Removing Tables)")
|
| 653 |
print("=" * 80)
|
| 654 |
+
|
| 655 |
text_chunks = self._chunk_text_content(text)
|
| 656 |
self.text_chunks_path = self._save_text_chunks(text_chunks)
|
| 657 |
+
|
| 658 |
print("\n" + "=" * 80)
|
| 659 |
print("STEP 4: Creating Final Chunks")
|
| 660 |
print("=" * 80)
|
| 661 |
+
|
| 662 |
all_chunks = []
|
|
|
|
|
|
|
| 663 |
all_chunks.extend(text_chunks)
|
| 664 |
+
|
|
|
|
| 665 |
if self.table_csv_path:
|
| 666 |
table_chunks = self._create_table_chunks(self.table_csv_path)
|
| 667 |
all_chunks.extend(table_chunks)
|
|
|
|
| 668 |
self._save_table_chunks(table_chunks)
|
| 669 |
+
|
|
|
|
| 670 |
self.chunks = [c['content'] for c in all_chunks]
|
| 671 |
self.chunk_metadata = all_chunks
|
| 672 |
+
|
| 673 |
print(f"\nTotal chunks created: {len(self.chunks)}")
|
| 674 |
print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
|
| 675 |
print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
|
| 676 |
print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
|
| 677 |
print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")
|
| 678 |
+
|
|
|
|
| 679 |
self._save_manifest(all_chunks)
|
| 680 |
+
|
| 681 |
print("\n" + "=" * 80)
|
| 682 |
print("STEP 5: Creating Embeddings")
|
| 683 |
print("=" * 80)
|
| 684 |
+
|
| 685 |
print("Loading embedding model...")
|
| 686 |
self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 687 |
+
|
| 688 |
print("Creating embeddings for all chunks...")
|
| 689 |
embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)
|
| 690 |
+
|
| 691 |
print("Building FAISS index...")
|
| 692 |
dimension = embeddings.shape[1]
|
| 693 |
self.index = faiss.IndexFlatL2(dimension)
|
| 694 |
self.index.add(np.array(embeddings).astype('float32'))
|
| 695 |
+
|
| 696 |
print("\n" + "=" * 80)
|
| 697 |
+
print("STEP 6: Initializing LLM API")
|
| 698 |
print("=" * 80)
|
| 699 |
+
|
| 700 |
+
# ✅ API already configured in __init__
|
| 701 |
+
print(f"API URL: {self.api_url}")
|
| 702 |
+
print(f"Model: {self.model_name}")
|
| 703 |
+
print("LLM API ready!")
|
| 704 |
+
|
| 705 |
print("\n" + "=" * 80)
|
| 706 |
print("SETUP COMPLETE!")
|
| 707 |
print("=" * 80)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
|
| 709 |
def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
|
| 710 |
"""Retrieve relevant chunks with metadata"""
|
|
|
|
| 811 |
return prompt
|
| 812 |
|
| 813 |
def ask(self, question: str) -> str:
|
| 814 |
+
"""Ask a question to the chatbot with learning from past conversations"""
|
| 815 |
if question.lower() in ["reset data", "reset"]:
|
| 816 |
self.chat_history = []
|
| 817 |
self.chat_embeddings = []
|
|
|
|
| 830 |
# Search through past conversations for similar questions
|
| 831 |
relevant_past_chats = self._search_chat_history(resolved_question, k=5)
|
| 832 |
|
| 833 |
+
# Retrieve relevant chunks
|
| 834 |
retrieved_data = self._retrieve(resolved_question, k=20)
|
| 835 |
|
| 836 |
+
# Build prompt
|
| 837 |
prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)
|
| 838 |
|
| 839 |
+
# ✅ NEW: Call Hugging Face Router API
|
| 840 |
+
payload = {
|
| 841 |
+
"model": self.model_name,
|
| 842 |
+
"messages": [
|
| 843 |
+
{
|
| 844 |
+
"role": "user",
|
| 845 |
+
"content": prompt
|
| 846 |
+
}
|
| 847 |
+
],
|
| 848 |
+
"max_tokens": 512,
|
| 849 |
+
"temperature": 0.3
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
try:
|
| 853 |
+
response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=60)
|
| 854 |
+
response.raise_for_status()
|
| 855 |
+
result = response.json()
|
| 856 |
+
|
| 857 |
+
# Extract answer from response
|
| 858 |
+
answer = result["choices"][0]["message"]["content"]
|
| 859 |
+
|
| 860 |
+
except Exception as e:
|
| 861 |
+
print(f"Error calling LLM API: {e}")
|
| 862 |
+
answer = "I apologize, but I'm having trouble generating a response right now. Please try again."
|
| 863 |
|
| 864 |
+
# Update conversation context
|
| 865 |
self._update_conversation_context(question, answer)
|
| 866 |
|
| 867 |
+
# Store in history
|
| 868 |
chat_entry = {
|
| 869 |
'timestamp': datetime.now().isoformat(),
|
| 870 |
'question': question,
|
|
|
|
| 875 |
|
| 876 |
self.chat_history.append(chat_entry)
|
| 877 |
|
| 878 |
+
# Update chat history index
|
| 879 |
new_text = f"Q: {question}\nA: {answer}"
|
| 880 |
new_embedding = self.embeddings_model.encode([new_text])
|
| 881 |
|
|
|
|
| 888 |
|
| 889 |
self.chat_index.add(np.array(new_embedding).astype('float32'))
|
| 890 |
|
| 891 |
+
# Save to disk
|
| 892 |
self._save_chat_history()
|
| 893 |
self._save_learning_stats()
|
| 894 |
|