riteshraut committed on
Commit
9683d0d
·
1 Parent(s): 3b79f3b
Files changed (1) hide show
  1. rag_processor.py +32 -86
rag_processor.py CHANGED
@@ -1,8 +1,7 @@
1
  # rag_processor.py
2
 
3
  import os
4
- from dotenv import load_dotenv
5
- from operator import itemgetter # <--- ADD THIS IMPORT
6
 
7
  # LLM
8
  from langchain_groq import ChatGroq
@@ -10,90 +9,58 @@ from langchain_groq import ChatGroq
10
  # Prompting
11
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
12
 
13
- # Chains
14
- from langchain_core.runnables import RunnableParallel, RunnablePassthrough
15
- from langchain_core.output_parsers import StrOutputParser
16
  from langchain_core.runnables.history import RunnableWithMessageHistory
 
 
 
 
 
17
 
18
  def create_rag_chain(retriever, get_session_history_func):
19
  """
20
- Creates an advanced Retrieval-Augmented Generation (RAG) chain with hybrid search,
21
- query rewriting, answer refinement, and conversational memory.
22
-
23
- Args:
24
- retriever: A configured LangChain retriever object.
25
- get_session_history_func: A function to get the chat history for a session.
26
-
27
- Returns:
28
- A LangChain runnable object representing the RAG chain with memory.
29
-
30
- Raises:
31
- ValueError: If the GROQ_API_KEY is missing.
32
  """
33
- # Load environment variables from .env file
 
34
  api_key = os.getenv("GROQ_API_KEY")
35
  if not api_key:
36
- raise ValueError("GROQ_API_KEY not found in environment variables.")
37
-
38
- # --- 1. Initialize the LLM ---
39
- # Updated model_name to a standard, high-performance Groq model
40
- llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=1)
41
 
42
- # --- 2. Create Query Rewriting Chain 🧠 ---
43
- print("\nSetting up query rewriting chain...")
44
- rewrite_template = """You are an expert at rewriting user questions for a vector database.
45
- You are here to help the user with their document.
46
- Based on the chat history, reformulate the follow-up question to be a standalone question.
47
- This new query should be optimized to find the most relevant documents in a knowledge base.
48
- Do NOT answer the question, only provide the rewritten, optimized question.
49
 
50
- Chat History:
51
- {chat_history}
52
-
53
- Follow-up Question: {question}
54
- Standalone Question:"""
55
- rewrite_prompt = ChatPromptTemplate.from_messages([
56
- ("system", rewrite_template),
57
- MessagesPlaceholder(variable_name="chat_history"),
58
- ("human", "Based on our conversation, reformulate this question to be a standalone query: {question}")
59
- ])
60
- query_rewriter = rewrite_prompt | llm | StrOutputParser()
61
-
62
- # --- 3. Create Main RAG Chain with Memory ---
63
- print("\nSetting up main RAG chain...")
64
- rag_template = """You are an expert assistant named `Cognichat`.Whenver user ask you about who you are , simply say you are `Cognichat`.
65
- You are developed by Ritesh and Alish.
66
- Your job is to provide accurate and helpful answers based ONLY on the provided context.
67
- If the information is not in the context, clearly state that you don't know the answer.
68
- Provide a clear and concise answer.
69
 
70
  Context:
71
- {context}"""
 
72
  rag_prompt = ChatPromptTemplate.from_messages([
73
  ("system", rag_template),
74
  MessagesPlaceholder(variable_name="chat_history"),
75
  ("human", "{question}"),
76
  ])
77
 
78
- # ============================ FIX IS HERE ============================
79
- # Parallel process to fetch context and correctly pass through question and history.
80
- # We use itemgetter to select the specific keys from the input dictionary.
81
- setup_and_retrieval = RunnableParallel({
82
- "context": query_rewriter | retriever,
83
- "question": itemgetter("question"),
84
- "chat_history": itemgetter("chat_history"),
85
- })
86
- # =====================================================================
87
-
88
- # The initial RAG chain
89
  conversational_rag_chain = (
90
- setup_and_retrieval
 
 
 
 
91
  | rag_prompt
92
  | llm
93
  | StrOutputParser()
94
  )
95
 
96
- # Wrap the chain with memory management
97
  chain_with_memory = RunnableWithMessageHistory(
98
  conversational_rag_chain,
99
  get_session_history_func,
@@ -101,26 +68,5 @@ Context:
101
  history_messages_key="chat_history",
102
  )
103
 
104
- # --- 4. Create Answer Refinement Chain ✨ ---
105
- print("\nSetting up answer refinement chain...")
106
- refine_template = """You are an expert at editing and refining content.
107
- Your task is to take a given answer and improve its clarity, structure, and readability.
108
- Use formatting such as bold text, bullet points, or numbered lists where it enhances the explanation.
109
- Do not add any new information that wasn't in the original answer.
110
-
111
- Original Answer:
112
- {answer}
113
-
114
- Refined Answer:"""
115
- refine_prompt = ChatPromptTemplate.from_template(refine_template)
116
- refinement_chain = refine_prompt | llm | StrOutputParser()
117
-
118
- # --- 5. Combine Everything into the Final Chain ---
119
- # The final chain passes the output of the memory-enabled chain to the refinement chain
120
- # Note: We need to adapt the input for the refinement chain
121
- final_chain = (
122
- lambda input_dict: {"answer": chain_with_memory.invoke(input_dict, config=input_dict.get('config'))}
123
- ) | refinement_chain
124
-
125
- print("\nFinalizing the complete chain with memory...")
126
- return final_chain
 
1
  # rag_processor.py
2
 
3
  import os
4
+ from operator import itemgetter
 
5
 
6
  # LLM
7
  from langchain_groq import ChatGroq
 
9
  # Prompting
10
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
11
 
12
+ # Chains and Memory
 
 
13
  from langchain_core.runnables.history import RunnableWithMessageHistory
14
+ from langchain_core.output_parsers import StrOutputParser
15
+
16
def format_docs(docs):
    """Join the page contents of retrieved documents, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
19
 
20
def create_rag_chain(retriever, get_session_history_func):
    """
    Create a simplified Retrieval-Augmented Generation (RAG) chain with memory.

    The chain routes the user's question to the retriever, formats the
    retrieved documents into the prompt context, and answers with a single
    LLM call per query.

    Args:
        retriever: A configured LangChain retriever object.
        get_session_history_func: Callable mapping a session id to its chat
            message history (consumed by RunnableWithMessageHistory).

    Returns:
        A RunnableWithMessageHistory wrapping the conversational RAG chain.

    Raises:
        ValueError: If the GROQ_API_KEY environment variable is not set.
    """
    # --- 1. Get the API Key from Environment Secrets ---
    # This correctly reads the secret set on the Hugging Face Space.
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY secret not found in environment variables.")

    # --- 2. Initialize the LLM ---
    # NOTE(review): confirm "llama3-8b-8192" is still a live Groq model id —
    # Groq periodically decommissions older model names.
    llm = ChatGroq(model_name="llama3-8b-8192", api_key=api_key, temperature=0.7)

    # --- 3. Define the Conversational RAG Prompt ---
    # A single prompt handles context, chat history, and the user's question.
    rag_template = """You are an expert assistant named `CogniChat`, developed by Ritesh and Alish.
Your job is to provide accurate and helpful answers based ONLY on the provided context.
If the information to answer the question is not in the context, clearly state that the document does not contain the answer.
Be concise and clear in your responses. Use formatting like bold text or bullet points if it helps clarity.

Context:
{context}
"""
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # --- 4. Create the Core RAG Chain ---
    # The dict's values run in parallel: the question is piped through the
    # retriever and formatted into a context string, while the question and
    # chat history pass through unchanged for the prompt.
    conversational_rag_chain = (
        {
            "context": itemgetter("question") | retriever | format_docs,
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
        }
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # --- 5. Wrap the Chain with Memory Management ---
    # input_messages_key tells the wrapper which input field holds the user
    # message; history_messages_key is where past turns are injected.
    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("\n✅ Simplified RAG chain with memory created successfully.")
    return chain_with_memory