# Environment tweaks for CPU-only torch/transformers hosts. These MUST be set
# before any torch-dependent import below executes, which is why they sit
# above the langchain imports.
import os
# Intended to disable torch's fused scaled-dot-product attention path.
# NOTE(review): confirm torch/transformers actually reads this exact variable
# name — it is not a commonly documented knob.
os.environ['DISABLE_TORCH_SCALED_DOT_PRODUCT_ATTENTION'] = '1'
# NOTE(review): "TORCH_USE_CPU_DSA" looks like a typo for TORCH_USE_CUDA_DSA
# (device-side assertions); verify this variable has any effect.
os.environ['TORCH_USE_CPU_DSA'] = '0'
|
|
| from langchain_community.document_loaders import TextLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.vectorstores import FAISS |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain.chains import RetrievalQA, ConversationalRetrievalChain |
| from langchain_community.llms import HuggingFacePipeline |
| from langchain.memory import ConversationBufferWindowMemory |
|
|
| import pandas as pd |
| from langchain.schema import Document |
|
|
| |
def load_documents(path="data/lawbot/Final_Dataset.pkl", limit=1000, text_column="response"):
    """Load the law dataset from a pickled DataFrame as LangChain Documents.

    Args:
        path: Path to a pickled pandas DataFrame.
        limit: Maximum number of rows to load; ``None`` loads every row.
            Defaults to the previously hard-coded cap of 1000.
        text_column: DataFrame column holding the document text. Defaults to
            the previously hard-coded ``"response"`` column.

    Returns:
        list[Document]: one Document per row, with the row's text as
        ``page_content``.
    """
    # NOTE: read_pickle executes arbitrary code from the file — only load
    # pickles from trusted sources.
    df = pd.read_pickle(path)
    if limit is not None:
        df = df.head(limit)
    # Iterate the column directly instead of iterrows(): same values, but
    # avoids materializing a Series per row.
    return [Document(page_content=text) for text in df[text_column]]
|
|
| |
def split_documents(documents, chunk_size=1000, chunk_overlap=100):
    """Split documents into overlapping character chunks for embedding.

    Args:
        documents: Iterable of LangChain Documents to split.
        chunk_size: Target chunk length in characters (default preserves the
            previously hard-coded 1000).
        chunk_overlap: Characters of overlap between consecutive chunks
            (default preserves the previously hard-coded 100).

    Returns:
        list[Document]: the chunked documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(documents)
|
|
| |
def create_vectorstore(docs, model_name="all-MiniLM-L6-v2", batch_size=1):
    """Embed the chunks and build an in-memory FAISS index.

    Args:
        docs: Chunked Documents to index.
        model_name: sentence-transformers embedding model (default preserves
            the previously hard-coded MiniLM model).
        batch_size: Embedding batch size. The original hard-coded 1 is kept
            as the default for low-memory hosts; raise it for much faster
            indexing when memory allows.

    Returns:
        FAISS: a vector store over ``docs``.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        encode_kwargs={'batch_size': batch_size},
    )
    return FAISS.from_documents(docs, embeddings)
|
|
| |
def setup_rag_chain(vectorstore, model_id="google/flan-t5-base", k=3, memory_window=10):
    """Build a conversational RAG chain over the given vector store.

    Args:
        vectorstore: FAISS store to retrieve context from.
        model_id: HuggingFace seq2seq model id (default preserves the
            previously hard-coded flan-t5-base).
        k: Number of similar chunks to retrieve per question (default 3).
        memory_window: Number of past turns the buffer memory keeps
            (default 10).

    Returns:
        ConversationalRetrievalChain: chain with its own windowed memory.
    """
    llm = HuggingFacePipeline.from_model_id(
        model_id=model_id,
        task="text2text-generation",
        device=-1,  # -1 pins the pipeline to CPU
        model_kwargs={"temperature": 0, "max_length": 512},
    )
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k},
    )
    # NOTE(review): because the chain owns this memory, it supplies
    # "chat_history" itself on every call — a chat_history passed explicitly
    # at invoke time is overridden by the memory's contents. Verify callers
    # rely on the memory, not on the explicit argument.
    memory = ConversationBufferWindowMemory(
        k=memory_window,
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
    )
|
|
| |
def answer_query(query, qa_chain, chat_history=None):
    """Run one question through the conversational RAG chain.

    Args:
        query: The user's question.
        qa_chain: Any chain exposing ``invoke`` that accepts a mapping with
            ``question`` and ``chat_history`` keys and returns a mapping
            containing an ``answer`` key.
        chat_history: Optional prior conversation turns; an empty history is
            used when omitted.

    Returns:
        The chain's answer string.
    """
    history = [] if chat_history is None else chat_history
    payload = {"question": query, "chat_history": history}
    return qa_chain.invoke(payload)["answer"]
|
|
def _main():
    """Build the RAG pipeline end-to-end and answer one demo question."""
    corpus = load_documents()
    chunks = split_documents(corpus)
    store = create_vectorstore(chunks)
    chain = setup_rag_chain(store)
    question = "What is Section 498A?"
    reply = answer_query(question, chain)
    print(f"Q: {question}\nA: {reply}")


if __name__ == "__main__":
    _main()