Upload folder using huggingface_hub

advanced_rag.py (+15 -6)
```diff
@@ -34,6 +34,7 @@ import time
 print("Pydantic Version: ")
 print(pydantic.__version__)
 # Add Mistral imports with fallback handling
+
 try:
     from mistralai import Mistral
     MISTRAL_AVAILABLE = True
```
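This hunk just opens a blank line between the comment and the guarded import. For reference, the optional-dependency pattern in play is sketched below; the `except` branch is an assumption (the hunk cuts off before it), though the "not found" message it logs does appear as a context line in the next hunk.

```python
# Optional-dependency import with an availability flag. The try branch is
# verbatim from the diff; the except branch is an assumption (the hunk ends
# before it), though the file's "not found" message shows up in the next hunk.
try:
    from mistralai import Mistral
    MISTRAL_AVAILABLE = True
except ImportError:
    Mistral = None
    MISTRAL_AVAILABLE = False
    print("Mistral client library not found. Install with: pip install mistralai")
```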
```diff
@@ -45,7 +46,7 @@ except ImportError:
     debug_print("Mistral client library not found. Install with: pip install mistralai")
 
 def debug_print(message: str):
-    print(f"[{datetime.datetime.now().isoformat()}] {message}")
+    print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
 
 def word_count(text: str) -> int:
     return len(text.split())
```
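The one functional change here is `flush=True`: each debug line is pushed to stdout immediately rather than waiting in Python's output buffer, which matters when logs are tailed from a hosted app. A quick way to see the effect:

```python
import time

# With default buffering, these lines may appear in one burst when the
# buffer fills or the process exits; flush=True emits each immediately.
for i in range(3):
    print(f"[heartbeat] step {i}", flush=True)
    time.sleep(1)
```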
```diff
@@ -447,8 +448,7 @@ class ElevatedRagChain:
         if not mistral_api_key:
             raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
         try:
-            from mistralai import Mistral
-            from mistralai.exceptions import MistralException
+            from mistralai import Mistral
             debug_print("Mistral library imported successfully")
         except ImportError:
             debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
```
```diff
@@ -473,8 +473,7 @@ class ElevatedRagChain:
                 model="mistral-small-latest",
                 messages=[{"role": "user", "content": prompt}],
                 temperature=self.temperature,
-                top_p=self.top_p
-                max_tokens=32000
+                top_p=self.top_p
             )
             return response.choices[0].message.content
         except Exception as e:
```
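Beyond dropping the hard-coded `max_tokens=32000`, this also fixes a syntax error: the two removed lines had no comma between `top_p=self.top_p` and `max_tokens=32000`. Judging from the context lines, the surrounding call has roughly the shape below; the client construction and the literal sampling values are assumptions, since only the keyword arguments appear in the hunk.

```python
import os
from mistralai import Mistral

# Hypothetical surroundings for the hunk: only the keyword arguments
# between model= and the closing parenthesis appear in the diff itself.
client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
response = client.chat.complete(
    model="mistral-small-latest",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.7,  # self.temperature in the real code
    top_p=0.95,       # self.top_p in the real code
)
print(response.choices[0].message.content)
```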
```diff
@@ -601,16 +600,24 @@ class ElevatedRagChain:
             retrievers=[self.bm25_retriever, self.faiss_retriever],
             weights=[self.bm25_weight, self.faiss_weight]
         )
+
         base_runnable = RunnableParallel({
             "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
             "question": RunnableLambda(self.extract_question)
         }) | self.capture_context
-
+
+        # Ensure the prompt template is set
         self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
+        if self.rag_prompt is None:
+            raise ValueError("Prompt template could not be created from the given template.")
         prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
+
         self.str_output_parser = StrOutputParser()
         debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
         self.llm = self.create_llm_pipeline()
+        if self.llm is None:
+            raise ValueError("LLM pipeline creation failed.")
+
         def format_response(response: str) -> str:
             input_tokens = count_tokens(self.context + self.prompt_template)
             output_tokens = count_tokens(response)
```
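The new `is None` guards are defensive; `ChatPromptTemplate.from_template` raises on a bad template rather than returning `None`. The chain itself is LangChain Expression Language: a `RunnableParallel` fans the question out to the ensemble retriever and an identity branch, and the resulting dict is formatted into the prompt. A minimal sketch of that shape, with a stub standing in for the BM25/FAISS ensemble:

```python
# A minimal, self-contained sketch of the chain shape built above,
# assuming langchain-core is installed. stub_retriever stands in for
# the real EnsembleRetriever over BM25 and FAISS.
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

def stub_retriever(question: str) -> str:
    return f"<documents retrieved for: {question}>"

rag_prompt = ChatPromptTemplate.from_template(
    "Answer from the context.\n\nContext: {context}\n\nQuestion: {question}"
)

# Mirrors base_runnable | prompt_runnable from the diff: fan the question
# out to both branches, then format the combined dict into the prompt.
base_runnable = RunnableParallel({
    "context": RunnableLambda(stub_retriever),
    "question": RunnableLambda(lambda q: q),
})
prompt_runnable = RunnableLambda(lambda vars: rag_prompt.format(**vars))

chain = base_runnable | prompt_runnable
print(chain.invoke("What does the ensemble retriever return?"))
```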
```diff
@@ -620,10 +627,12 @@ class ElevatedRagChain:
             formatted += f"- **Generated using:** {self.llm_choice}\n"
             formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
             return formatted
+
         self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
         debug_print("Elevated RAG chain successfully built and ready to use.")
 
 
+
     def get_current_context(self) -> str:
         base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
         history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
```