Harsh12 commited on
Commit
b0b3cb1
·
verified ·
1 Parent(s): 04df341

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +94 -0
  2. doc_preprocessing.py +37 -0
  3. indexing.py +41 -0
  4. requirements.txt +0 -0
  5. retrieval.py +32 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
# Streamlit front-end for a RAG document-search app: upload a PDF/DOCX,
# index it in Pinecone with Cohere embeddings, then ask questions via LangChain.
import streamlit as st
from doc_preprocessing import load_and_split_document
from indexing import initialize_pinecone, delete_index
from retrieval import retrieve_documents
from langchain_cohere import CohereEmbeddings, ChatCohere
from dotenv import load_dotenv
import os
import time
from langchain_pinecone import PineconeVectorStore

# Load COHERE_API_KEY / PINECONE_API_KEY from a local .env file.
load_dotenv()

# Set API keys
cohere_api = os.getenv("COHERE_API_KEY")
pinecone_api = os.getenv("PINECONE_API_KEY")
# Chat model answers questions; the embeddings model vectorizes chunks.
# NOTE(review): embed-english-v2.0 produces 4096-dim vectors — this must match
# the index dimension created in indexing.py; verify if either changes.
cohere_chat_model = ChatCohere(cohere_api_key=cohere_api)
cohere_embeddings = CohereEmbeddings(cohere_api_key=cohere_api, user_agent="my-app", model="embed-english-v2.0")
def pretty_print_docs(docs):
    """Render retrieved documents as one human-readable string.

    Each document becomes a numbered "Document N:" section; sections are
    separated by blank lines.
    """
    sections = []
    for position, doc in enumerate(docs, start=1):
        sections.append(f"Document {position}:\n\n{doc.page_content}")
    return "\n\n".join(sections)
# Initialize session state
# `index_name` tracks the per-session Pinecone index and `retriever` caches the
# vector-store retriever, so Streamlit re-runs don't re-index the document.
if "index_name" not in st.session_state:
    st.session_state.index_name = None
if "retriever" not in st.session_state:
    st.session_state.retriever = None

st.title("RAG-Based Document Search with LangChain")

# Upload PDF or DOCX document
uploaded_file = st.file_uploader("Upload a PDF or DOCX Document", type=["pdf", "docx"])

# Input for user query
query = st.text_input("Ask a question related to the uploaded document:")
# Index the uploaded document exactly once per session: `index_name` is set on
# the first upload, so subsequent Streamlit re-runs skip this branch.
if uploaded_file is not None and st.session_state.index_name is None:
    # Detect file type from the extension
    file_type = uploaded_file.name.split(".")[-1].lower()

    # Create a unique index name for the session.
    # Pinecone index names cannot contain '.', so the timestamp is sanitized.
    user_index = f"user-{str(time.time()).replace('.', '-')}"
    st.session_state.index_name = user_index

    # Save the uploaded file to a container-friendly relative path
    file_path = os.path.join("data", uploaded_file.name)
    os.makedirs("data", exist_ok=True)  # Create 'data' if it doesn't exist

    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load and split the document, converting DOCX to PDF if necessary
    documents = load_and_split_document(file_path, file_type)

    # Initialize (create and connect to) the per-session Pinecone index
    index = initialize_pinecone(pinecone_api_key=pinecone_api, index_name=user_index)
    # Embed the chunks with Cohere and upsert them into the new index
    db = PineconeVectorStore.from_documents(
        documents=documents,
        embedding=cohere_embeddings,
        index_name=user_index,
    )

    # Store the retriever in session state (top-5 nearest chunks per query)
    st.session_state.retriever = db.as_retriever(search_kwargs={"k": 5})
    st.write("Data Indexed Successfully")
# Add a submit button for query input (only shown once a document is indexed)
if st.session_state.retriever:
    if st.button("Submit"):
        if not query.strip():
            # Guard: don't invoke the retrieval/LLM chain on an empty question.
            st.warning("Please enter a question before submitting.")
        else:
            # Retrieve documents and generate an answer for the query
            result = retrieve_documents(query=query, retriever=st.session_state.retriever, llm=cohere_chat_model)

            st.header("Response:")
            st.write(result["answer"])

            st.write("-------------------------------------------------------------------")

            st.header("Context:")
            # The prompt in retrieval.py instructs the model to answer exactly
            # "I don't know" for out-of-context questions, so this substring
            # check detects the no-context case.
            if "I don't know" in result["answer"]:
                st.markdown("Can't fetch the context!!")
            else:
                st.markdown(pretty_print_docs(result["context"]))
# Clean up index when user ends the session: deleting the index frees Pinecone
# resources, and resetting session state lets a new upload start fresh.
if st.button("End Session and Delete Index"):
    if st.session_state.index_name:
        delete_index(st.session_state.index_name, pinecone_api)
        st.success(f"Index '{st.session_state.index_name}' deleted.")
        st.session_state.index_name = None
        st.session_state.retriever = None
doc_preprocessing.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # document_processing.py
2
+ from langchain_community.document_loaders import PyPDFLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from docx import Document
5
+ import pdfkit
6
+
7
+
def convert_docx_to_pdf(docx_file, pdf_file):
    """Convert a .docx file to a .pdf.

    pdfkit (wkhtmltopdf) renders HTML, not DOCX, so passing the .docx path
    straight to ``pdfkit.from_file`` does not work. Instead, the document's
    paragraph text is extracted with python-docx, written to an intermediate
    HTML file, and that file is converted to PDF. (The original code also
    re-saved the docx onto itself, which was a pointless round-trip and has
    been removed. Note: paragraph text only — tables/images are not carried
    over.)

    Args:
        docx_file: Path to the source .docx file.
        pdf_file: Path where the resulting .pdf is written.
    """
    import html  # stdlib; local import keeps module-level deps unchanged

    document = Document(docx_file)
    # Render each paragraph as an escaped <p> so wkhtmltopdf gets valid HTML.
    body = "\n".join(
        f"<p>{html.escape(paragraph.text)}</p>" for paragraph in document.paragraphs
    )
    html_file = f"{docx_file}.html"
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(f"<html><body>{body}</body></html>")

    # Convert the intermediate HTML file to pdf using pdfkit
    pdfkit.from_file(html_file, pdf_file)
def load_and_split_document(file_path, file_type):
    """Load a PDF (converting DOCX to PDF first when needed) and split it
    into overlapping text chunks ready for embedding/indexing.
    """
    if file_type == "docx":
        # PyPDFLoader only reads PDFs, so convert the DOCX first.
        converted_path = file_path.replace(".docx", ".pdf")
        convert_docx_to_pdf(file_path, converted_path)
        file_path = converted_path  # point at the newly created PDF

    # Read every page of the PDF into raw LangChain documents.
    raw_documents = PyPDFLoader(file_path).load()

    # Chunk with a recursive character splitter: 1200-char chunks, 200 overlap.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
    return splitter.split_documents(raw_documents)
indexing.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # indexing.py
2
+ from pinecone import Pinecone, ServerlessSpec
3
+ import time
4
+
# Initialize Pinecone and create unique index
def initialize_pinecone(pinecone_api_key, index_name):
    """Create (if needed) and connect to a serverless Pinecone index.

    Args:
        pinecone_api_key: Pinecone API key.
        index_name: Name of the index to create/connect to.

    Returns:
        A Pinecone Index handle ready for upserts and queries.
    """
    pc = Pinecone(api_key=pinecone_api_key)

    existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

    # check if index already exists (it shouldn't if this is first time)
    if index_name not in existing_indexes:
        # if it does not exist, create a serverless index
        pc.create_index(
            index_name,
            # 4096 = dimensionality of Cohere embed-english-v2.0, the model
            # used in app.py. (The original comment said "ada 002", but
            # text-embedding-ada-002 is 1536-dim — the value was right, the
            # comment was wrong.)
            dimension=4096,
            metric='dotproduct',
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
        # wait for index to be initialized before using it
        while not pc.describe_index(index_name).status['ready']:
            time.sleep(1)

    # connect to index
    index = pc.Index(index_name)
    time.sleep(1)

    return index
37
+
# Delete Pinecone index when user quits
def delete_index(index_name, pinecone_api_key):
    """Permanently remove the given Pinecone index (frees serverless resources)."""
    client = Pinecone(api_key=pinecone_api_key)
    client.delete_index(index_name)
requirements.txt ADDED
Binary file (3.91 kB). View file
 
retrieval.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retrieval.py
2
+ from langchain.retrievers import ContextualCompressionRetriever
3
+ from langchain_cohere import CohereRerank
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain.chains import create_retrieval_chain
7
+
def retrieve_documents(query, retriever, llm):
    """Answer `query` with a rerank-compressed RAG pipeline.

    Pipeline: base retriever -> Cohere rerank compression -> stuff-documents
    chain driven by a restrictive prompt -> retrieval chain.

    Returns the chain's response dict (includes "answer" and "context").
    """
    # Wrap the base retriever so retrieved chunks are reranked by Cohere
    # before being stuffed into the prompt.
    reranker = CohereRerank(model="rerank-english-v3.0")
    reranking_retriever = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=retriever
    )

    prompt = """You are a good assistant that answers questions. Your knowledge is strictly limited to the following piece of context. Use it to answer the question at the end.
    If the answer can't be found in the context, just say you don't know. *DO NOT* try to make up an answer.
    If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.
    **MOST IMPORTANT: If question is not related to the context, just say "I don't know".**

    Context: {context}
    Question: {input}

    """

    # Build the answer chain and run the full retrieval pipeline.
    answer_chain = create_stuff_documents_chain(llm, ChatPromptTemplate.from_template(prompt))
    chain = create_retrieval_chain(reranking_retriever, answer_chain)
    return chain.invoke({"input": query})