# modules/lawbot/rag_with_langchain.py
# Retrieval-augmented generation (RAG) pipeline for the Saathi lawbot,
# built with LangChain, FAISS, and a local Hugging Face model.

import os

# Must be set before torch is imported (directly or via transformers /
# langchain) for these flags to take effect.
os.environ['DISABLE_TORCH_SCALED_DOT_PRODUCT_ATTENTION'] = '1'
os.environ['TORCH_USE_CPU_DSA'] = '0'

import pandas as pd

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS


# Load documents from Final_Dataset.pkl
def load_documents(path="data/lawbot/Final_Dataset.pkl"):
    df = pd.read_pickle(path)
    # The 'response' column is assumed to contain the text data; limit to the
    # first 1000 rows to avoid memory issues on small hosts.
    documents = [Document(page_content=row['response']) for _, row in df.head(1000).iterrows()]
    return documents


# Split documents into chunks for better retrieval
def split_documents(documents):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    docs = text_splitter.split_documents(documents)
    return docs


# Create a FAISS vector store from the chunks using HuggingFace embeddings
def create_vectorstore(docs):
    # batch_size=1 keeps peak memory low at the cost of slower indexing
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2", encode_kwargs={'batch_size': 1})
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore
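

# Optional sketch, not part of the original flow: a FAISS index can be saved
# to disk so the embedding step is not repeated on every startup. The
# "data/lawbot/faiss_index" path below is an assumption, not an existing file.
def load_or_create_vectorstore(docs, index_dir="data/lawbot/faiss_index"):
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2", encode_kwargs={'batch_size': 1})
    if os.path.isdir(index_dir):
        # allow_dangerous_deserialization is required because FAISS metadata
        # is pickled; only load indexes you created yourself.
        return FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    vectorstore = FAISS.from_documents(docs, embeddings)
    vectorstore.save_local(index_dir)
    return vectorstore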


# Set up a ConversationalRetrievalChain with a local HuggingFace pipeline LLM,
# the FAISS retriever, and a windowed conversation memory. A local pipeline is
# used instead of HuggingFaceEndpoint, which can raise StopIteration when no
# inference provider is found for the repo.
def setup_rag_chain(vectorstore):
    # Run google/flan-t5-base locally on CPU (device=-1). Generation options
    # belong in pipeline_kwargs, not model_kwargs; greedy decoding is the
    # pipeline default, so no temperature is set.
    llm = HuggingFacePipeline.from_model_id(
        model_id="google/flan-t5-base",
        task="text2text-generation",
        device=-1,
        pipeline_kwargs={"max_length": 512},
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    # Keep the last 10 exchanges; the chain reads and writes "chat_history" itself.
    memory = ConversationBufferWindowMemory(k=10, memory_key="chat_history", return_messages=True)
    qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)
    return qa_chain
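

# Alternative sketch: the hosted Inference API can stand in for the local
# pipeline. This assumes HUGGINGFACEHUB_API_TOKEN is set and that a serverless
# provider serves the repo; without one, HuggingFaceEndpoint can fail at
# lookup time, which is why the local pipeline above is the default here.
def setup_rag_chain_hosted(vectorstore):
    from langchain_community.llms import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="google/flan-t5-base",
        task="text2text-generation",
        max_new_tokens=512,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    memory = ConversationBufferWindowMemory(k=10, memory_key="chat_history", return_messages=True)
    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)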


# Ask one question. The chain's attached memory supplies and updates the chat
# history on every call (and overrides any history passed in the inputs), so
# only the new question needs to be provided.
def answer_query(query, qa_chain):
    result = qa_chain.invoke({"question": query})
    return result["answer"]


if __name__ == "__main__":
    # Example usage: build the index once, then answer a query.
    documents = load_documents()
    docs = split_documents(documents)
    vectorstore = create_vectorstore(docs)
    qa_chain = setup_rag_chain(vectorstore)

    query = "What is Section 498A?"
    answer = answer_query(query, qa_chain)
    print(f"Q: {query}\nA: {answer}")
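
    # Illustrative follow-up turn (question text is an example, not from the
    # original script): the windowed memory lets the chain resolve references
    # like "it" against the previous exchange.
    follow_up = "What punishment does it prescribe?"
    print(f"Q: {follow_up}\nA: {answer_query(follow_up, qa_chain)}")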