fahmiaziz98 committed
Commit 91c6bea · 1 Parent(s): ba900f0
init
Files changed:
- app.py +60 -20
- requirements.txt +10 -9
- src/{retriever/__init__.py → __init__.py} +0 -0
- src/indexing/{document_processor.py → document_processing.py} +0 -0
- src/indexing/vectore_store.py +16 -12
- src/llm/__init__.py +0 -0
- src/llm/llm_interface.py +10 -0
- src/state.py +7 -0
- src/tools_retrieval/__init__.py +0 -0
- src/{retriever → tools_retrieval}/retriever.py +13 -1
- src/workflow.py +128 -0
app.py
CHANGED
@@ -1,11 +1,13 @@
 import os
 import streamlit as st
-from src.indexing.document_processor import DocumentProcessor
+from src.indexing.document_processing import DocumentProcessor
 from src.indexing.vectore_store import VectorStoreManager
-from src.retriever.retriever import RetrieverManager
+from src.tools_retrieval.retriever import RetrieverManager
+from src.workflow import RAGWorkflow
+
 
 UPLOAD_FOLDER = "uploads/"
-PERSIST_DIRECTORY = "chroma_db"
+PERSIST_DIRECTORY = "./chroma_db"
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
 
@@ -15,6 +17,9 @@ if "retriever" not in st.session_state:
     st.session_state.retriever = None
 if "vector_store" not in st.session_state:
     st.session_state.vector_store = None
+if "workflow" not in st.session_state:
+    st.session_state.workflow = None
+
 
 st.set_page_config(
     page_title="RAG Chatbot",
@@ -23,36 +28,71 @@ st.set_page_config(
 )
 st.title("Agentic RAG Chatbot")
 
-
 with st.sidebar:
     st.header("PDF Upload")
     uploaded_file = st.file_uploader("Upload your PDF", type=["pdf"])
     st.info("Supported file type: PDF")
+    process_button = st.button("Process PDF")
 
-    if uploaded_file:
+    if uploaded_file and process_button:
         with st.spinner("Processing PDF..."):
-
             file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
             with open(file_path, "wb") as f:
                 f.write(uploaded_file.getbuffer())
 
-
             doc_processor = DocumentProcessor()
             chunks = doc_processor.load_and_split_pdf(file_path)
 
-            # Create vector store
             vector_store_manager = VectorStoreManager()
-            vector_store = vector_store_manager.index_documents(
-                documents=chunks,
-                collection_name=uploaded_file.name,
-                persist_directory=PERSIST_DIRECTORY
-            )
+            vector_store = vector_store_manager.index_documents(chunks, uploaded_file.name, PERSIST_DIRECTORY)
+
             st.session_state.vector_store = vector_store
+            st.success("PDF processed and indexed successfully!")
 
-
             retriever_manager = RetrieverManager(vector_store)
-
-
-            st.
+            retriever_tool = retriever_manager.create_retriever(chunks)
+            st.session_state.retriever = retriever_tool
+            st.success("Retriever tool created successfully!")
+            rag_workflow = RAGWorkflow(retriever_tool)
+            workflow = rag_workflow.compile()
+            st.session_state.workflow = workflow
 
-
+
+# Display chat messages
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+if prompt := st.chat_input("Ask a question about your document"):
+    # Add user message to chat history
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    # Generate response
+    with st.chat_message("assistant"):
+        if st.session_state.retriever is None:
+            final_response = "Please upload a PDF document first."
+        else:
+            with st.spinner("Thinking..."):
+                # Retrieve relevant documents
+                inputs = {
+                    "messages": [
+                        ("user", prompt),
+                    ]
+                }
+
+                # Generate response using workflow
+                if st.session_state.workflow is not None:
+                    response = st.session_state.workflow.invoke(inputs)
+                    final_response = response["messages"][-1].content
+                else:
+                    final_response = "Please upload a PDF document first."
+
+        st.markdown(final_response)
+        st.session_state.messages.append({"role": "assistant", "content": final_response})
 
+# Add clear chat button
+if st.sidebar.button("Clear Chat"):
+    st.session_state.messages = []
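Review note: the new end-to-end flow can also be exercised outside Streamlit. A minimal sketch, assuming the packages in requirements.txt are installed, GROQ_API_KEY is exported, and "example.pdf" is a placeholder for any local PDF:

    from src.indexing.document_processing import DocumentProcessor
    from src.indexing.vectore_store import VectorStoreManager
    from src.tools_retrieval.retriever import RetrieverManager
    from src.workflow import RAGWorkflow

    # Index a PDF ("example.pdf" stands in for a real path)
    chunks = DocumentProcessor().load_and_split_pdf("example.pdf")
    vector_store = VectorStoreManager().index_documents(chunks, "example.pdf", "./chroma_db")

    # Build the retriever tool and compile the graph, as app.py does
    retriever_tool = RetrieverManager(vector_store).create_retriever(chunks)
    workflow = RAGWorkflow(retriever_tool).compile()

    # Same input shape app.py builds from st.chat_input; the compiled graph
    # returns the accumulated state, so the last message holds the answer
    result = workflow.invoke({"messages": [("user", "What is this document about?")]})
    print(result["messages"][-1].content)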
requirements.txt
CHANGED
@@ -1,11 +1,12 @@
 langchain
 langgraph
 langchain-huggingface
-langchain-
-
-
+langchain-groq
+langchain-community
+scikit-learn
 langchain-chroma
-pypdf
+pypdf==5.1.0
 tiktoken
 rank_bm25
-
+fastembed
+flashrank
src/{retriever/__init__.py → __init__.py}
RENAMED
File without changes
src/indexing/{document_processor.py → document_processing.py}
RENAMED
File without changes
src/indexing/vectore_store.py
CHANGED
@@ -1,23 +1,27 @@
 from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import SKLearnVectorStore
 from langchain_chroma import Chroma
 
 class VectorStoreManager:
     def __init__(self, embedding_model="intfloat/multilingual-e5-small"):
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
 
-    def create_vector_store(self, collection_name
+    def create_vector_store(self, collection_name, presist_directory):
         """Create a new vector store"""
-
-
-
-
-
-
-
-
+        # vector_store = SKLearnVectorStore.from_documents(
+        #     documents=documents,
+        #     embedding=self.embeddings,
+        # )
+        vector_store = Chroma(
+            collection_name=collection_name,
+            embedding_function=self.embeddings,
+            persist_directory=presist_directory,  # Where to save data locally, remove if not necessary
+        )
+
+        return vector_store
 
-    def index_documents(self, documents, collection_name
+    def index_documents(self, documents, collection_name, presist_directory):
         """Index documents into vector store"""
-        vector_store = self.create_vector_store(collection_name,
+        vector_store = self.create_vector_store(collection_name, presist_directory)
         vector_store.add_documents(documents=documents)
         return vector_store
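Because create_vector_store now takes only a collection name and persist directory, a collection indexed in an earlier run can be reopened without re-indexing. (The SKLearnVectorStore import is only used by the commented-out alternative.) A usage sketch with illustrative names:

    from src.indexing.vectore_store import VectorStoreManager

    manager = VectorStoreManager()
    # Reopen the collection that was created when "example.pdf" was indexed
    store = manager.create_vector_store("example.pdf", "./chroma_db")

    # Standard langchain-chroma similarity query over the persisted chunks
    hits = store.similarity_search("what is the main topic?", k=4)
    for doc in hits:
        print(doc.page_content[:80])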
src/llm/__init__.py
ADDED
File without changes
src/llm/llm_interface.py
ADDED
@@ -0,0 +1,10 @@
+import os
+from langchain_groq import ChatGroq
+
+llm_groq = ChatGroq(
+    model="llama3-8b-8192",
+    temperature=0.1,
+    api_key=os.getenv("GROQ_API_KEY"),
+    max_retries=3,
+    streaming=True,
+)
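A quick smoke test for the shared client; note that the module reads GROQ_API_KEY when it is first imported, so the variable must be set beforehand:

    from src.llm.llm_interface import llm_groq

    # invoke() returns an AIMessage; .content holds the text
    reply = llm_groq.invoke("Reply with the single word: ok")
    print(reply.content)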
src/state.py
ADDED
@@ -0,0 +1,7 @@
+from typing import Annotated, Sequence
+from typing_extensions import TypedDict
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+class AgentState(TypedDict):
+    messages: Annotated[Sequence[BaseMessage], add_messages]
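The add_messages annotation is what lets each workflow node return only its new messages: LangGraph applies it as a reducer that appends to the running history rather than replacing it. A rough illustration, calling the reducer directly:

    from langchain_core.messages import AIMessage, HumanMessage
    from langgraph.graph.message import add_messages

    history = [HumanMessage(content="hi")]
    update = [AIMessage(content="hello")]

    # Appends new messages (matching ids would update in place), which is
    # why nodes can return partial updates like {"messages": [response]}
    merged = add_messages(history, update)
    print([m.content for m in merged])  # ['hi', 'hello']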
src/tools_retrieval/__init__.py
ADDED
File without changes
src/{retriever → tools_retrieval}/retriever.py
RENAMED
@@ -1,7 +1,10 @@
-from
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import FlashrankRerank
 
+from langchain.tools.retriever import create_retriever_tool
+
 
 class RetrieverManager:
     def __init__(self, vector_store):
@@ -31,4 +34,13 @@ class RetrieverManager:
         return ContextualCompressionRetriever(
             base_compressor=compressor,
             base_retriever=base_retriever
         )
+
+    def create_retriever(self, documents):
+        base_retriever = self.create_ensemble_retriever(documents)
+        compression_retriever = self.create_compression_retriever(base_retriever=base_retriever)
+        return create_retriever_tool(
+            compression_retriever,
+            "retrieve_docs",
+            "use tools for search through the user's provided documents and return relevant information about user query.",
+        )
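The value create_retriever returns is a LangChain tool rather than a bare retriever, which is what lets workflow.py pass it straight to bind_tools and ToolNode. Tools are also runnables, so a sketch like this works for ad-hoc testing (vector_store and chunks as produced by the indexing step in app.py):

    from src.tools_retrieval.retriever import RetrieverManager

    manager = RetrieverManager(vector_store)
    tool = manager.create_retriever(chunks)

    print(tool.name)  # "retrieve_docs"
    # Runs ensemble retrieval plus FlashRank compression and returns the
    # relevant chunks as a single string
    print(tool.invoke("what does the document conclude?"))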
src/workflow.py
ADDED
@@ -0,0 +1,128 @@
+from langchain import hub
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.messages import HumanMessage
+from typing import Literal
+from pydantic import BaseModel, Field
+from langchain_core.prompts import PromptTemplate
+
+from langgraph.graph import END, StateGraph, START
+from langgraph.prebuilt import ToolNode, tools_condition
+from .state import AgentState
+from src.llm.llm_interface import llm_groq
+
+
+class GradeDocs(BaseModel):
+    binary_score: str = Field(description="Relevance score 'yes' or 'no'")
+
+
+
+class RAGWorkflow:
+    def __init__(self, retriever_tool):
+        self.workflow = StateGraph(AgentState)
+        self.tools = [retriever_tool]
+        self.retrieve = ToolNode([retriever_tool])
+        self._setup_nodes()
+        self._setup_edges()
+
+    def _setup_nodes(self):
+        self.workflow.add_node("agent", self._agent_node)
+        self.workflow.add_node("retrieve", self.retrieve)
+        self.workflow.add_node("generate", self._generator_node)
+
+        self.workflow.add_node("rewrite", self._rewrite_node)
+
+    def _setup_edges(self):
+        self.workflow.add_edge(START, "agent")
+        self.workflow.add_conditional_edges(
+            "agent",
+            tools_condition,
+            {
+                "tools": "retrieve",
+                END: END
+            }
+        )
+
+        self.workflow.add_conditional_edges(
+            "retrieve",
+            self._grade_docs,
+        )
+        self.workflow.add_edge("generate", END)
+        self.workflow.add_edge("rewrite", "agent")
+
+    def compile(self):
+        return self.workflow.compile()
+
+
+    def _agent_node(self, state):
+        print("---CALL AGENT---")
+        messages = state["messages"]
+
+        model = llm_groq.bind_tools(self.tools)
+        response = model.invoke(messages[0].content)
+        return {"messages": [response]}
+
+
+    def _generator_node(self, state):
+        print("---GENERATE---")
+        messages = state["messages"]
+        question = messages[0].content
+        docs = messages[-1].content
+
+        prompt = hub.pull("rlm/rag-prompt")
+        rag_chain = prompt | llm_groq | StrOutputParser()
+
+        response = rag_chain.invoke({"context": docs, "question": question})
+        return {"messages": [response]}
+
+
+
+    def _rewrite_node(self, state):
+        print("---REWRITE---")
+        messages = state["messages"]
+        question = messages[0].content
+
+        msg = [
+            HumanMessage(
+                content=f""" \n
+    Look at the input and try to reason about the underlying semantic intent / meaning. \n
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Formulate an improved question: """,
+            )
+        ]
+
+        response = llm_groq.invoke(msg)
+        return {"messages": [response]}
+
+
+
+    def _grade_docs(self, state):
+        print("---CHECK RELEVANCE---")
+
+        llm_with_tool = llm_groq.with_structured_output(GradeDocs)
+
+        prompt = PromptTemplate(
+            template="""You are a grader assessing relevance of a retrieved document to a user question. \n
+            Here is the retrieved document: \n\n {context} \n\n
+            Here is the user question: {question} \n
+            If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
+            Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
+            input_variables=["context", "question"],
+        )
+
+        chain = prompt | llm_with_tool
+
+        messages = state["messages"]
+        question = messages[0].content
+        docs = messages[-1].content
+
+        scored_result = chain.invoke({"question": question, "context": docs})
+
+        if scored_result.binary_score == "yes":
+            print("---DECISION: DOCS RELEVANT---")
+            return "generate"
+        print("---DECISION: DOCS NOT RELEVANT---")
+        return "rewrite"
+
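Routing note: the conditional edge out of "retrieve" passes no path map, so it relies on _grade_docs returning the literal node names "generate" and "rewrite". One way to watch that routing during review is to stream the compiled graph node by node; a sketch, with retriever_tool built as in src/tools_retrieval/retriever.py:

    from src.workflow import RAGWorkflow

    graph = RAGWorkflow(retriever_tool).compile()

    # stream() yields one dict per executed node, keyed by node name, so
    # the agent -> retrieve -> generate/rewrite path is visible in order
    inputs = {"messages": [("user", "Summarize the document.")]}
    for step in graph.stream(inputs):
        print(list(step.keys()))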