aman1762 commited on
Commit
26fe9e2
·
verified ·
1 Parent(s): 2c837a2

Update rag_chain.py

Browse files
Files changed (1) hide show
  1. rag_chain.py +31 -5
rag_chain.py CHANGED
@@ -1,15 +1,41 @@
1
- from langchain_community.chains import RetrievalQA
2
  from langchain_groq import ChatGroq
 
 
 
3
 
4
  def build_rag_chain(vectorstore, groq_api_key):
 
 
5
  llm = ChatGroq(
6
  api_key=groq_api_key,
7
  model="llama3-8b-8192",
8
  temperature=0
9
  )
10
 
11
- return RetrievalQA.from_chain_type(
12
- llm=llm,
13
- retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
14
- return_source_documents=True
 
 
 
 
 
 
 
15
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain_groq import ChatGroq
2
+ from langchain_core.prompts import ChatPromptTemplate
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from langchain_core.runnables import RunnablePassthrough
5
 
6
  def build_rag_chain(vectorstore, groq_api_key):
7
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
8
+
9
  llm = ChatGroq(
10
  api_key=groq_api_key,
11
  model="llama3-8b-8192",
12
  temperature=0
13
  )
14
 
15
+ prompt = ChatPromptTemplate.from_template(
16
+ """
17
+ You are an expert software engineer.
18
+ Answer the question using ONLY the context below.
19
+
20
+ Context:
21
+ {context}
22
+
23
+ Question:
24
+ {question}
25
+ """
26
  )
27
+
28
+ def format_docs(docs):
29
+ return "\n\n".join(d.page_content for d in docs)
30
+
31
+ chain = (
32
+ {
33
+ "context": retriever | format_docs,
34
+ "question": RunnablePassthrough()
35
+ }
36
+ | prompt
37
+ | llm
38
+ | StrOutputParser()
39
+ )
40
+
41
+ return chain