Alpha108 committed on
Commit
496f188
·
verified ·
1 Parent(s): b3dfc21

Update rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +10 -17
rag_utils.py CHANGED
@@ -1,27 +1,20 @@
1
- # rag_utils.py
2
-
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain.vectorstores import FAISS
5
  from langchain.chains import RetrievalQA
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_groq import ChatGroq
 
8
 
9
def create_vectorstore_from_text(text: str):
    """Chunk *text*, embed the chunks, and index them in a FAISS store.

    The text is split with ``RecursiveCharacterTextSplitter`` using
    ``chunk_size=500`` and ``chunk_overlap=50``. The chunks are embedded with
    the ``sentence-transformers/all-MiniLM-L6-v2`` model, configured with
    ``device="cpu"``.

    Args:
        text: Raw text to index.

    Returns:
        The vector store produced by ``FAISS.from_texts``.
    """
    # Split first, then load the embedding model, matching the original order.
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
    ).split_text(text)

    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )

    return FAISS.from_texts(chunks, embedding=embedder)
20
 
21
def create_rag_chain(vectorstore):
    """Build a ``RetrievalQA`` chain over *vectorstore* using a Groq chat model.

    Args:
        vectorstore: Object exposing ``as_retriever``; the retriever is
            requested with ``search_kwargs={"k": 3}``.

    Returns:
        The chain returned by ``RetrievalQA.from_chain_type``.
    """
    doc_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    # The model name and temperature are fixed inside this function.
    groq_chat = ChatGroq(model_name="llama3-8b-8192", temperature=0)
    return RetrievalQA.from_chain_type(llm=groq_chat, retriever=doc_retriever)
 
 
 
1
  from langchain.text_splitter import RecursiveCharacterTextSplitter
2
  from langchain.vectorstores import FAISS
3
  from langchain.chains import RetrievalQA
4
  from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from langchain_groq import ChatGroq
6
+ from langchain.docstore.document import Document
7
 
8
def create_vectorstore_from_text(
    documents,
    embeddings,
    *,
    chunk_size=500,
    chunk_overlap=50,
):
    """Build a FAISS vector store from raw text or from prepared documents.

    Args:
        documents: Either a single string, which is split into chunks and
            wrapped in ``Document`` objects, or an iterable of ``Document``
            objects that is indexed as given.
        embeddings: Embedding model passed to ``FAISS.from_documents``.
        chunk_size: ``chunk_size`` forwarded to
            ``RecursiveCharacterTextSplitter`` when *documents* is a string.
        chunk_overlap: ``chunk_overlap`` forwarded to the same splitter.

    Returns:
        The vector store produced by ``FAISS.from_documents``.

    Raises:
        ValueError: If there is nothing to index (an empty string or an
            empty collection of documents).
    """
    if isinstance(documents, str):
        # A bare string is accepted for convenience: chunk it and wrap each
        # chunk so that the rest of the function only deals with Documents.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        documents = [
            Document(page_content=chunk)
            for chunk in splitter.split_text(documents)
        ]
    else:
        # Materialize once so the emptiness check below is reliable and
        # one-shot iterables are not consumed before indexing.
        documents = list(documents)

    if not documents:
        # Fail early with a clear message instead of letting the index
        # build fail on an empty input further down.
        raise ValueError(
            "create_vectorstore_from_text received no content to index"
        )

    return FAISS.from_documents(documents, embedding=embeddings)
17
 
18
def create_rag_chain(llm, vectorstore):
    """Build a ``RetrievalQA`` chain from a language model and a vector store.

    Args:
        llm: Language model passed to ``RetrievalQA.from_chain_type``.
        vectorstore: Object exposing ``as_retriever``; the retriever is
            requested with ``search_kwargs={"k": 3}``.

    Returns:
        The chain returned by ``RetrievalQA.from_chain_type``.
    """
    search_options = {"k": 3}
    doc_retriever = vectorstore.as_retriever(search_kwargs=search_options)
    return RetrievalQA.from_chain_type(retriever=doc_retriever, llm=llm)