alx-d committed
Commit c857530 · verified · 1 parent: 23b48b8

Upload folder using huggingface_hub

Files changed (1):
  1. advanced_rag.py +15 -6
advanced_rag.py CHANGED
@@ -34,6 +34,7 @@ import time
 print("Pydantic Version: ")
 print(pydantic.__version__)
 # Add Mistral imports with fallback handling
+
 try:
     from mistralai import Mistral
     MISTRAL_AVAILABLE = True
@@ -45,7 +46,7 @@ except ImportError:
     debug_print("Mistral client library not found. Install with: pip install mistralai")
 
 def debug_print(message: str):
-    print(f"[{datetime.datetime.now().isoformat()}] {message}")
+    print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
 
 def word_count(text: str) -> int:
     return len(text.split())
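The flush=True addition is about log visibility: when stdout is block-buffered (as it is when piped into a container log, e.g. on a Hugging Face Space), debug lines can lag far behind the events they describe. A tiny illustration of the difference (how stdout is attached determines the buffering, so "may sit" is deliberate):

    import sys

    print("tick", flush=True)   # written out immediately
    print("tock")               # may sit in the buffer until it fills
    sys.stdout.flush()          # manual alternative to flush=True

Setting the standard PYTHONUNBUFFERED=1 environment variable achieves the same effect process-wide without touching call sites.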
@@ -447,8 +448,7 @@ class ElevatedRagChain:
         if not mistral_api_key:
             raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
         try:
-            from mistralai import Mistral
-            from mistralai.exceptions import MistralException
+            from mistralai import Mistral
             debug_print("Mistral library imported successfully")
         except ImportError:
             debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
@@ -473,8 +473,7 @@ class ElevatedRagChain:
                 model="mistral-small-latest",
                 messages=[{"role": "user", "content": prompt}],
                 temperature=self.temperature,
-                top_p=self.top_p,
-                max_tokens=32000
+                top_p=self.top_p
             )
             return response.choices[0].message.content
         except Exception as e:
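Removing the hard-coded max_tokens=32000 leaves the completion budget to the API default, which presumably avoids over-limit rejections when the prompt already consumes most of the context window. For reference, a minimal standalone call in the same 1.x chat.complete shape this file uses (the sampling values and key handling below are illustrative, not the repo's):

    import os
    from mistralai import Mistral

    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
    response = client.chat.complete(
        model="mistral-small-latest",
        messages=[{"role": "user", "content": "Summarize RAG in one line."}],
        temperature=0.7,
        top_p=0.95,   # max_tokens omitted, matching this commit
    )
    print(response.choices[0].message.content)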
@@ -601,16 +600,24 @@ class ElevatedRagChain:
             retrievers=[self.bm25_retriever, self.faiss_retriever],
             weights=[self.bm25_weight, self.faiss_weight]
         )
+
         base_runnable = RunnableParallel({
             "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
             "question": RunnableLambda(self.extract_question)
         }) | self.capture_context
-        # Wrap the prompt template in a RunnableLambda
+
+        # Ensure the prompt template is set
         self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
+        if self.rag_prompt is None:
+            raise ValueError("Prompt template could not be created from the given template.")
         prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
+
         self.str_output_parser = StrOutputParser()
         debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
         self.llm = self.create_llm_pipeline()
+        if self.llm is None:
+            raise ValueError("LLM pipeline creation failed.")
+
         def format_response(response: str) -> str:
             input_tokens = count_tokens(self.context + self.prompt_template)
             output_tokens = count_tokens(response)
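This hunk is the heart of the chain assembly: RunnableParallel fans the incoming question out to the ensemble retriever while also passing it through unchanged, and the resulting dict is piped through the prompt, the LLM, and the formatter. A self-contained sketch of the same LCEL composition with stand-in components (the fake retriever and LLM below are illustrations, not this repo's classes):

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnableLambda, RunnableParallel

    prompt = ChatPromptTemplate.from_template(
        "Answer from the context.\nContext: {context}\nQuestion: {question}"
    )
    fake_retriever = RunnableLambda(lambda q: [f"doc mentioning {q}"])
    fake_llm = RunnableLambda(lambda p: f"answer derived from: {p[:60]}...")

    chain = (
        RunnableParallel({
            "context": fake_retriever,
            "question": RunnableLambda(lambda q: q),
        })
        | RunnableLambda(lambda vars: prompt.format(**vars))
        | fake_llm
    )
    print(chain.invoke("hybrid BM25 + FAISS retrieval"))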
@@ -620,10 +627,12 @@ class ElevatedRagChain:
             formatted += f"- **Generated using:** {self.llm_choice}\n"
             formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
             return formatted
+
         self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
         debug_print("Elevated RAG chain successfully built and ready to use.")
 
 
+
     def get_current_context(self) -> str:
         base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
         history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
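With the pipeline assigned to self.elevated_rag_chain, it is invoked like any LCEL runnable. A hedged usage sketch (the exact input shape depends on extract_question, which this diff does not show, so the plain-string call is an assumption):

    # Assuming `rag` is an ElevatedRagChain whose chain has been built:
    answer = rag.elevated_rag_chain.invoke("What does the corpus say about X?")
    print(answer)   # format_response appends token counts and history info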
 