Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- app.py +94 -0
- doc_preprocessing.py +37 -0
- indexing.py +41 -0
- requirements.txt +0 -0
- retrieval.py +32 -0
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import streamlit as st
|
| 3 |
+
from doc_preprocessing import load_and_split_document
|
| 4 |
+
from indexing import initialize_pinecone, delete_index
|
| 5 |
+
from retrieval import retrieve_documents
|
| 6 |
+
from langchain_cohere import CohereEmbeddings, ChatCohere
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from langchain_pinecone import PineconeVectorStore
|
| 11 |
+
|
| 12 |
+
# Load environment variables from a local .env file (COHERE_API_KEY, PINECONE_API_KEY)
load_dotenv()

# Set API keys
cohere_api = os.getenv("COHERE_API_KEY")
pinecone_api = os.getenv("PINECONE_API_KEY")

# Chat model for answer generation; embedding model for indexing/retrieval.
# NOTE(review): the Pinecone index in indexing.py is created with dimension=4096 —
# confirm this matches embed-english-v2.0's output dimension.
cohere_chat_model = ChatCohere(cohere_api_key=cohere_api)
cohere_embeddings = CohereEmbeddings(cohere_api_key=cohere_api, user_agent="my-app", model="embed-english-v2.0")
+
def pretty_print_docs(docs):
    """Format retrieved documents as a numbered, blank-line-separated string."""
    sections = []
    for position, doc in enumerate(docs, start=1):
        sections.append(f"Document {position}:\n\n{doc.page_content}")
    return "\n\n".join(sections)
|
| 22 |
+
|
| 23 |
+
# Streamlit session state survives script reruns; use it to keep the
# per-session Pinecone index name and retriever across interactions.
if "index_name" not in st.session_state:
    st.session_state.index_name = None
if "retriever" not in st.session_state:
    st.session_state.retriever = None

st.title("RAG-Based Document Search with LangChain")

# Upload PDF or DOCX document
uploaded_file = st.file_uploader("Upload a PDF or DOCX Document", type=["pdf", "docx"])

# Input for user query
query = st.text_input("Ask a question related to the uploaded document:")

# Index the uploaded document once per session (index_name is None only
# before the first successful indexing).
if uploaded_file is not None and st.session_state.index_name is None:
    # Detect file type from the extension (lowercased: "pdf" or "docx")
    file_type = uploaded_file.name.split(".")[-1].lower()

    # Create a unique index name for the session from the current timestamp
    # ('.' replaced with '-' to keep the name Pinecone-friendly)
    user_index = f"user-{str(time.time()).replace('.', '-')}"
    st.session_state.index_name = user_index

    # Save the uploaded file to a container-friendly path
    file_path = os.path.join("data", uploaded_file.name)  # Use relative path
    os.makedirs("data", exist_ok=True)  # Create the 'data' directory if it doesn't exist

    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load and split the document, converting if necessary
    documents = load_and_split_document(file_path, file_type)

    # Initialize Pinecone index and embed the chunks into it
    index = initialize_pinecone(pinecone_api_key=pinecone_api, index_name=user_index)
    db = PineconeVectorStore.from_documents(
        documents=documents,
        embedding=cohere_embeddings,
        index_name=user_index,
    )

    # Store the retriever (top-5 chunks) in session state for later queries
    st.session_state.retriever = db.as_retriever(search_kwargs={"k": 5})
    st.write("Data Indexed Successfully")

# Answer queries once a retriever exists for this session
if st.session_state.retriever:
    if st.button("Submit"):
        # Retrieve documents based on the query and generate an answer
        result = retrieve_documents(query=query, retriever=st.session_state.retriever, llm=cohere_chat_model)

        st.header("Response:")
        st.write(result["answer"])

        st.write("-------------------------------------------------------------------")

        st.header("Context:")
        # The retrieval prompt instructs the LLM to answer "I don't know" for
        # off-context questions; suppress the context display in that case.
        if "I don't know" in result["answer"]:
            st.markdown("Can't fetch the context!!")
        else:
            st.markdown(pretty_print_docs(result["context"]))

# Clean up index when user ends the session
if st.button("End Session and Delete Index"):
    if st.session_state.index_name:
        delete_index(st.session_state.index_name, pinecone_api)
        st.success(f"Index '{st.session_state.index_name}' deleted.")
        st.session_state.index_name = None
        st.session_state.retriever = None
|
doc_preprocessing.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# document_processing.py
|
| 2 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 4 |
+
from docx import Document
|
| 5 |
+
import pdfkit
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def convert_docx_to_pdf(docx_file, pdf_file):
    """
    Convert a .docx file to a .pdf using pdfkit.

    Args:
        docx_file: Path to the source .docx file.
        pdf_file: Path the generated .pdf is written to.
    """
    # Parse and re-save the .docx in place. NOTE(review): this round-trip
    # looks redundant — Document(...) only re-serializes the same file;
    # confirm whether it can be dropped.
    document = Document(docx_file)
    document.save(docx_file)  # fixed: no need to wrap the path in an f-string

    # NOTE(review): pdfkit/wkhtmltopdf converts HTML (or plain text) input,
    # not .docx binaries — passing a .docx here is unlikely to produce a
    # faithful PDF. Consider docx2pdf or LibreOffice headless instead.
    pdfkit.from_file(docx_file, pdf_file)
| 18 |
+
def load_and_split_document(file_path, file_type):
    """
    Load a PDF or DOCX document and split it into overlapping text chunks.

    If the file is a DOCX it is first converted to PDF, then the PDF is
    loaded with PyPDFLoader and chunked with a recursive character splitter.

    Args:
        file_path: Path to the uploaded document on disk.
        file_type: Lowercased file extension ("pdf" or "docx").

    Returns:
        A list of LangChain Document chunks (chunk_size=1200, chunk_overlap=200).
    """
    # Convert DOCX to PDF if necessary
    if file_type == "docx":
        # Swap only the final extension for ".pdf". Using rsplit instead of
        # str.replace(".docx", ".pdf") avoids corrupting paths that contain
        # ".docx" elsewhere and handles any extension casing (e.g. ".DOCX").
        pdf_file = file_path.rsplit(".", 1)[0] + ".pdf"
        convert_docx_to_pdf(file_path, pdf_file)
        file_path = pdf_file  # Update file path to newly created PDF

    # Load the PDF document
    loader = PyPDFLoader(file_path)
    raw_documents = loader.load()

    # Chunk the text using recursive character splitter
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
    documents = text_splitter.split_documents(raw_documents)

    return documents
|
indexing.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# indexing.py
|
| 2 |
+
from pinecone import Pinecone, ServerlessSpec
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
# Initialize Pinecone and create unique index
|
| 6 |
+
def initialize_pinecone(pinecone_api_key, index_name):
    """
    Create (if needed) and connect to a serverless Pinecone index.

    Args:
        pinecone_api_key: Pinecone API key.
        index_name: Name of the index to create/connect to.

    Returns:
        A Pinecone Index handle for `index_name`.
    """
    spec = ServerlessSpec(
        cloud="aws",
        region="us-east-1"
    )

    # fixed: use the parameter directly instead of the redundant local alias
    pc = Pinecone(api_key=pinecone_api_key)

    existing_indexes = [
        index_info["name"] for index_info in pc.list_indexes()]

    # check if index already exists (it shouldn't if this is first time)
    if index_name not in existing_indexes:
        # if does not exist, create index
        pc.create_index(
            index_name,
            # fixed comment: 4096 is the dimensionality of the Cohere
            # embed-english-v2.0 model used by app.py (ada-002 is 1536)
            dimension=4096,
            metric='dotproduct',
            spec=spec
        )
        # wait for index to be initialized
        # NOTE(review): no timeout — this can spin forever if creation fails
        while not pc.describe_index(index_name).status['ready']:
            time.sleep(1)

    # connect to index
    index = pc.Index(index_name)
    time.sleep(1)  # brief settle delay before first use

    return index
| 38 |
+
# Delete Pinecone index when user quits
|
| 39 |
+
def delete_index(index_name, pinecone_api_key):
    """Delete the Pinecone index named `index_name` (used at session end)."""
    Pinecone(api_key=pinecone_api_key).delete_index(index_name)
|
requirements.txt
ADDED
|
Binary file (3.91 kB). View file
|
|
|
retrieval.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# retrieval.py
|
| 2 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
| 3 |
+
from langchain_cohere import CohereRerank
|
| 4 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 5 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 6 |
+
from langchain.chains import create_retrieval_chain
|
| 7 |
+
|
| 8 |
+
def retrieve_documents(query, retriever, llm):
    """
    Answer `query` with a RAG chain: rerank retrieved chunks, then stuff
    them into a grounded prompt for the LLM.

    Args:
        query: The user's question.
        retriever: Base retriever over the indexed document chunks.
        llm: Chat model used to generate the final answer.

    Returns:
        The dict produced by the retrieval chain's invoke(); app.py reads
        "answer" (generated text) and "context" (the reranked documents).
    """
    # Apply Cohere reranking model on top of the base retriever so only the
    # most relevant chunks reach the prompt.
    compressor = CohereRerank(model="rerank-english-v3.0")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    # Grounding prompt: the model must answer only from {context} and reply
    # "I don't know" otherwise (app.py keys its context display off that phrase).
    prompt = """You are a good assistant that answers questions. Your knowledge is strictly limited to the following piece of context. Use it to answer the question at the end.
If the answer can't be found in the context, just say you don't know. *DO NOT* try to make up an answer.
If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.
**MOST IMPORTANT: If question is not related to the context, just say "I don't know".**

Context: {context}
Question: {input}

"""

    prompt_template = ChatPromptTemplate.from_template(prompt)

    # Stuff the reranked documents into the prompt and run the full chain.
    document_chain = create_stuff_documents_chain(llm, prompt_template)
    retrieval_chain = create_retrieval_chain(compression_retriever, document_chain)
    response = retrieval_chain.invoke({"input":query})

    return response
|