Jaheen07 committed on
Commit
4233eaf
·
verified ·
1 Parent(s): 265d329

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +68 -60
chatbot.py CHANGED
@@ -13,6 +13,7 @@ import os
13
  import pickle
14
  from datetime import datetime
15
  from collections import Counter
 
16
 
17
 
18
  class RAGChatbot:
@@ -23,38 +24,38 @@ class RAGChatbot:
23
  self.chunk_metadata = []
24
  self.index = None
25
  self.embeddings_model = None
26
- self.llm_client = None
 
 
 
 
 
27
  self.chat_history = []
28
  self.output_dir = "./"
29
  self.table_csv_path = None
30
  self.text_chunks_path = None
31
  self.history_file = os.path.join(self.output_dir, "chat_history.pkl")
32
-
33
- # Chat history embeddings and index
34
  self.chat_embeddings = []
35
  self.chat_index = None
36
  self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")
37
-
38
- # Learning statistics
39
  self.query_patterns = Counter()
40
  self.feedback_scores = {}
41
  self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")
42
-
43
- # ADD THIS NEW SECTION:
44
  self.conversation_context = {
45
  'current_employee': None,
46
  'last_mentioned_entities': []
47
  }
48
-
49
  os.makedirs(self.output_dir, exist_ok=True)
50
-
51
- # Load existing chat history and learning data
52
  self._load_chat_history()
53
  self._load_learning_stats()
54
-
55
  self._setup()
56
-
57
- # Build chat history index after setup
58
  self._build_chat_history_index()
59
 
60
  def _load_chat_history(self):
@@ -637,83 +638,73 @@ class RAGChatbot:
637
  print("\n" + "=" * 80)
638
  print("STEP 1: Loading PDF")
639
  print("=" * 80)
640
-
641
  text = self._load_pdf_text()
642
  print(f"Loaded PDF with {len(text)} characters")
643
-
644
  print("\n" + "=" * 80)
645
  print("STEP 2: Extracting and Merging Tables")
646
  print("=" * 80)
647
-
648
  self.table_csv_path = self._extract_and_merge_tables()
649
-
650
  print("\n" + "=" * 80)
651
  print("STEP 3: Chunking Text Content (Removing Tables)")
652
  print("=" * 80)
653
-
654
  text_chunks = self._chunk_text_content(text)
655
  self.text_chunks_path = self._save_text_chunks(text_chunks)
656
-
657
  print("\n" + "=" * 80)
658
  print("STEP 4: Creating Final Chunks")
659
  print("=" * 80)
660
-
661
  all_chunks = []
662
-
663
- # Add text chunks
664
  all_chunks.extend(text_chunks)
665
-
666
- # Add table chunks
667
  if self.table_csv_path:
668
  table_chunks = self._create_table_chunks(self.table_csv_path)
669
  all_chunks.extend(table_chunks)
670
- # Save chunked table text to file
671
  self._save_table_chunks(table_chunks)
672
-
673
- # Extract content and metadata
674
  self.chunks = [c['content'] for c in all_chunks]
675
  self.chunk_metadata = all_chunks
676
-
677
  print(f"\nTotal chunks created: {len(self.chunks)}")
678
  print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
679
  print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
680
  print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
681
  print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")
682
-
683
- # Save manifest
684
  self._save_manifest(all_chunks)
685
-
686
  print("\n" + "=" * 80)
687
  print("STEP 5: Creating Embeddings")
688
  print("=" * 80)
689
-
690
  print("Loading embedding model...")
691
  self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
692
-
693
  print("Creating embeddings for all chunks...")
694
  embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)
695
-
696
  print("Building FAISS index...")
697
  dimension = embeddings.shape[1]
698
  self.index = faiss.IndexFlatL2(dimension)
699
  self.index.add(np.array(embeddings).astype('float32'))
700
-
701
  print("\n" + "=" * 80)
702
- print("STEP 6: Initializing LLM")
703
  print("=" * 80)
704
-
705
- self.llm_client = InferenceClient(token=self.hf_token)
706
- self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
707
-
 
 
708
  print("\n" + "=" * 80)
709
  print("SETUP COMPLETE!")
710
  print("=" * 80)
711
- print(f"Files created in: {self.output_dir}/")
712
- print(f" - {os.path.basename(self.table_csv_path) if self.table_csv_path else 'No table CSV'}")
713
- print(f" - {os.path.basename(self.text_chunks_path)}")
714
- print(f" - chunk_manifest.json")
715
- print(f" - {os.path.basename(self.history_file)}")
716
- print("=" * 80 + "\n")
717
 
718
  def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
719
  """Retrieve relevant chunks with metadata"""
@@ -820,6 +811,7 @@ class RAGChatbot:
820
  return prompt
821
 
822
  def ask(self, question: str) -> str:
 
823
  if question.lower() in ["reset data", "reset"]:
824
  self.chat_history = []
825
  self.chat_embeddings = []
@@ -838,25 +830,41 @@ class RAGChatbot:
838
  # Search through past conversations for similar questions
839
  relevant_past_chats = self._search_chat_history(resolved_question, k=5)
840
 
841
- # Retrieve relevant chunks (use resolved question for better retrieval)
842
  retrieved_data = self._retrieve(resolved_question, k=20)
843
 
844
- # Build prompt with both document context and learned information
845
  prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)
846
 
847
- # ✅ CORRECT: Use text_generation for InferenceClient
848
- answer = self.llm_client.text_generation(
849
- prompt,
850
- model=self.model_name,
851
- max_new_tokens=512,
852
- temperature=0.3,
853
- return_full_text=False
854
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
 
856
- # Update conversation context after each exchange
857
  self._update_conversation_context(question, answer)
858
 
859
- # Store in history with timestamp and metadata
860
  chat_entry = {
861
  'timestamp': datetime.now().isoformat(),
862
  'question': question,
@@ -867,7 +875,7 @@ class RAGChatbot:
867
 
868
  self.chat_history.append(chat_entry)
869
 
870
- # Update chat history index with new conversation
871
  new_text = f"Q: {question}\nA: {answer}"
872
  new_embedding = self.embeddings_model.encode([new_text])
873
 
@@ -880,7 +888,7 @@ class RAGChatbot:
880
 
881
  self.chat_index.add(np.array(new_embedding).astype('float32'))
882
 
883
- # Save to disk after each conversation
884
  self._save_chat_history()
885
  self._save_learning_stats()
886
 
 
13
  import pickle
14
  from datetime import datetime
15
  from collections import Counter
16
+ import requests
17
 
18
 
19
  class RAGChatbot:
 
24
  self.chunk_metadata = []
25
  self.index = None
26
  self.embeddings_model = None
27
+
28
+ # ✅ NEW: API configuration
29
+ self.api_url = "https://router.huggingface.co/v1/chat/completions"
30
+ self.headers = {"Authorization": f"Bearer {hf_token}"}
31
+ self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
32
+
33
  self.chat_history = []
34
  self.output_dir = "./"
35
  self.table_csv_path = None
36
  self.text_chunks_path = None
37
  self.history_file = os.path.join(self.output_dir, "chat_history.pkl")
38
+
 
39
  self.chat_embeddings = []
40
  self.chat_index = None
41
  self.chat_embedding_file = os.path.join(self.output_dir, "chat_embeddings.pkl")
42
+
 
43
  self.query_patterns = Counter()
44
  self.feedback_scores = {}
45
  self.stats_file = os.path.join(self.output_dir, "learning_stats.pkl")
46
+
 
47
  self.conversation_context = {
48
  'current_employee': None,
49
  'last_mentioned_entities': []
50
  }
51
+
52
  os.makedirs(self.output_dir, exist_ok=True)
53
+
 
54
  self._load_chat_history()
55
  self._load_learning_stats()
56
+
57
  self._setup()
58
+
 
59
  self._build_chat_history_index()
60
 
61
  def _load_chat_history(self):
 
638
  print("\n" + "=" * 80)
639
  print("STEP 1: Loading PDF")
640
  print("=" * 80)
641
+
642
  text = self._load_pdf_text()
643
  print(f"Loaded PDF with {len(text)} characters")
644
+
645
  print("\n" + "=" * 80)
646
  print("STEP 2: Extracting and Merging Tables")
647
  print("=" * 80)
648
+
649
  self.table_csv_path = self._extract_and_merge_tables()
650
+
651
  print("\n" + "=" * 80)
652
  print("STEP 3: Chunking Text Content (Removing Tables)")
653
  print("=" * 80)
654
+
655
  text_chunks = self._chunk_text_content(text)
656
  self.text_chunks_path = self._save_text_chunks(text_chunks)
657
+
658
  print("\n" + "=" * 80)
659
  print("STEP 4: Creating Final Chunks")
660
  print("=" * 80)
661
+
662
  all_chunks = []
 
 
663
  all_chunks.extend(text_chunks)
664
+
 
665
  if self.table_csv_path:
666
  table_chunks = self._create_table_chunks(self.table_csv_path)
667
  all_chunks.extend(table_chunks)
 
668
  self._save_table_chunks(table_chunks)
669
+
 
670
  self.chunks = [c['content'] for c in all_chunks]
671
  self.chunk_metadata = all_chunks
672
+
673
  print(f"\nTotal chunks created: {len(self.chunks)}")
674
  print(f" - Q&A chunks: {sum(1 for c in all_chunks if c['type'] == 'qa')}")
675
  print(f" - Text chunks: {sum(1 for c in all_chunks if c['type'] == 'text')}")
676
  print(f" - Table full: {sum(1 for c in all_chunks if c['type'] == 'table_full')}")
677
  print(f" - Employee records: {sum(1 for c in all_chunks if c['type'] == 'table_row')}")
678
+
 
679
  self._save_manifest(all_chunks)
680
+
681
  print("\n" + "=" * 80)
682
  print("STEP 5: Creating Embeddings")
683
  print("=" * 80)
684
+
685
  print("Loading embedding model...")
686
  self.embeddings_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
687
+
688
  print("Creating embeddings for all chunks...")
689
  embeddings = self.embeddings_model.encode(self.chunks, show_progress_bar=True)
690
+
691
  print("Building FAISS index...")
692
  dimension = embeddings.shape[1]
693
  self.index = faiss.IndexFlatL2(dimension)
694
  self.index.add(np.array(embeddings).astype('float32'))
695
+
696
  print("\n" + "=" * 80)
697
+ print("STEP 6: Initializing LLM API")
698
  print("=" * 80)
699
+
700
+ # API already configured in __init__
701
+ print(f"API URL: {self.api_url}")
702
+ print(f"Model: {self.model_name}")
703
+ print("LLM API ready!")
704
+
705
  print("\n" + "=" * 80)
706
  print("SETUP COMPLETE!")
707
  print("=" * 80)
 
 
 
 
 
 
708
 
709
  def _retrieve(self, query: str, k: int = 10) -> List[Tuple[str, Dict]]:
710
  """Retrieve relevant chunks with metadata"""
 
811
  return prompt
812
 
813
  def ask(self, question: str) -> str:
814
+ """Ask a question to the chatbot with learning from past conversations"""
815
  if question.lower() in ["reset data", "reset"]:
816
  self.chat_history = []
817
  self.chat_embeddings = []
 
830
  # Search through past conversations for similar questions
831
  relevant_past_chats = self._search_chat_history(resolved_question, k=5)
832
 
833
+ # Retrieve relevant chunks
834
  retrieved_data = self._retrieve(resolved_question, k=20)
835
 
836
+ # Build prompt
837
  prompt = self._build_prompt(resolved_question, retrieved_data, relevant_past_chats)
838
 
839
+ # ✅ NEW: Call Hugging Face Router API
840
+ payload = {
841
+ "model": self.model_name,
842
+ "messages": [
843
+ {
844
+ "role": "user",
845
+ "content": prompt
846
+ }
847
+ ],
848
+ "max_tokens": 512,
849
+ "temperature": 0.3
850
+ }
851
+
852
+ try:
853
+ response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=60)
854
+ response.raise_for_status()
855
+ result = response.json()
856
+
857
+ # Extract answer from response
858
+ answer = result["choices"][0]["message"]["content"]
859
+
860
+ except Exception as e:
861
+ print(f"Error calling LLM API: {e}")
862
+ answer = "I apologize, but I'm having trouble generating a response right now. Please try again."
863
 
864
+ # Update conversation context
865
  self._update_conversation_context(question, answer)
866
 
867
+ # Store in history
868
  chat_entry = {
869
  'timestamp': datetime.now().isoformat(),
870
  'question': question,
 
875
 
876
  self.chat_history.append(chat_entry)
877
 
878
+ # Update chat history index
879
  new_text = f"Q: {question}\nA: {answer}"
880
  new_embedding = self.embeddings_model.encode([new_text])
881
 
 
888
 
889
  self.chat_index.add(np.array(new_embedding).astype('float32'))
890
 
891
+ # Save to disk
892
  self._save_chat_history()
893
  self._save_learning_stats()
894