fahmiaziz98 committed on
Commit
91c6bea
·
1 Parent(s): ba900f0
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import os
2
  import streamlit as st
3
- from src.indexing.document_processor import DocumentProcessor
4
  from src.indexing.vectore_store import VectorStoreManager
5
- from src.retriever.retriever import RetrieverManager
 
 
6
 
7
  UPLOAD_FOLDER = "uploads/"
8
- PERSIST_DIRECTORY = "chroma_db/"
9
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
10
  os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
11
 
@@ -15,6 +17,9 @@ if "retriever" not in st.session_state:
15
  st.session_state.retriever = None
16
  if "vector_store" not in st.session_state:
17
  st.session_state.vector_store = None
 
 
 
18
 
19
  st.set_page_config(
20
  page_title="RAG Chatbot",
@@ -23,36 +28,71 @@ st.set_page_config(
23
  )
24
  st.title("Agentic RAG Chatbot")
25
 
26
-
27
  with st.sidebar:
28
  st.header("PDF Upload")
29
  uploaded_file = st.file_uploader("Upload your PDF", type=["pdf"])
30
  st.info("Supported file type: PDF")
 
31
 
32
- if uploaded_file:
33
- with st.spinner("Processing PDF..."):
34
-
35
  file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
36
  with open(file_path, "wb") as f:
37
  f.write(uploaded_file.getbuffer())
38
 
39
-
40
  doc_processor = DocumentProcessor()
41
  chunks = doc_processor.load_and_split_pdf(file_path)
42
 
43
- # Buat vector store
44
  vector_store_manager = VectorStoreManager()
45
- vector_store = vector_store_manager.index_documents(
46
- documents=chunks,
47
- collection_name=uploaded_file.name,
48
- persist_directory=PERSIST_DIRECTORY
49
- )
50
  st.session_state.vector_store = vector_store
51
-
52
- # Setup retriever
53
  retriever_manager = RetrieverManager(vector_store)
54
- base_retriever = retriever_manager.create_base_retriever()
55
- compression_retriever = retriever_manager.create_compression_retriever(base_retriever)
56
- st.session_state.retriever = compression_retriever
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- st.success("File processed successfully!")
 
 
 
1
  import os
2
  import streamlit as st
3
+ from src.indexing.document_processing import DocumentProcessor
4
  from src.indexing.vectore_store import VectorStoreManager
5
+ from src.tools_retrieval.retriever import RetrieverManager
6
+ from src.workflow import RAGWorkflow
7
+
8
 
9
  UPLOAD_FOLDER = "uploads/"
10
+ PERSIST_DIRECTORY = "./chroma_db"
11
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
12
  os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
13
 
 
17
  st.session_state.retriever = None
18
  if "vector_store" not in st.session_state:
19
  st.session_state.vector_store = None
20
+ if "workflow" not in st.session_state:
21
+ st.session_state.workflow = None
22
+
23
 
24
  st.set_page_config(
25
  page_title="RAG Chatbot",
 
28
  )
29
  st.title("Agentic RAG Chatbot")
30
 
 
31
  with st.sidebar:
32
  st.header("PDF Upload")
33
  uploaded_file = st.file_uploader("Upload your PDF", type=["pdf"])
34
  st.info("Supported file type: PDF")
35
+ process_button = st.button("Process PDF")
36
 
37
+ if uploaded_file and process_button:
38
+ with st.spinner("Processing PDF..."):
 
39
  file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
40
  with open(file_path, "wb") as f:
41
  f.write(uploaded_file.getbuffer())
42
 
 
43
  doc_processor = DocumentProcessor()
44
  chunks = doc_processor.load_and_split_pdf(file_path)
45
 
 
46
  vector_store_manager = VectorStoreManager()
47
+ vector_store = vector_store_manager.index_documents(chunks, uploaded_file.name, PERSIST_DIRECTORY)
48
+
 
 
 
49
  st.session_state.vector_store = vector_store
50
+ st.success("PDF processed and indexed successfully!")
51
+
52
  retriever_manager = RetrieverManager(vector_store)
53
+ retriever_tool = retriever_manager.create_retriever(chunks)
54
+ st.session_state.retriever = retriever_tool
55
+ st.success("Retriever tool created successfully!")
56
+ rag_workflow = RAGWorkflow(retriever_tool)
57
+ workflow = rag_workflow.compile()
58
+ st.session_state.workflow = workflow
59
+
60
+
61
+
62
+ # Display chat messages
63
+ for message in st.session_state.messages:
64
+ with st.chat_message(message["role"]):
65
+ st.markdown(message["content"])
66
+
67
+ if prompt := st.chat_input("Ask a question about your document"):
68
+ # Add user message to chat history
69
+ st.session_state.messages.append({"role": "user", "content": prompt})
70
+ with st.chat_message("user"):
71
+ st.markdown(prompt)
72
+
73
+ # Generate response
74
+ with st.chat_message("assistant"):
75
+ if st.session_state.retriever is None:
76
+ final_response = "Please upload a PDF document first."
77
+ else:
78
+ with st.spinner("Thinking..."):
79
+ # Retrieve relevant documents
80
+ inputs = {
81
+ "messages": [
82
+ ("user", prompt),
83
+ ]
84
+ }
85
+
86
+ # Generate response using workflow
87
+ if st.session_state.workflow is not None:
88
+ response = st.session_state.workflow.invoke(inputs)
89
+ final_response = response["messages"][-1].content
90
+ else:
91
+ final_response = "Please upload a PDF document first."
92
+
93
+ st.markdown(final_response)
94
+ st.session_state.messages.append({"role": "assistant", "content": final_response})
95
 
96
+ # Add clear chat button
97
+ if st.sidebar.button("Clear Chat"):
98
+ st.session_state.messages = []
requirements.txt CHANGED
@@ -1,11 +1,12 @@
1
- langchain
2
  langgraph
3
- langchain-huggingface
4
- langchain-google-genai
5
- google-ai-generativelanguage==0.6.15
6
- langchain-community
7
  langchain-chroma
8
- pypdf
9
- tiktoken
10
- rank_bm25
11
- flashrank
 
 
1
+ langchain
2
  langgraph
3
+ langchain-huggingface
4
+ langchain-groq
5
+ langchain-community
6
+ scikit-learn
7
  langchain-chroma
8
+ pypdf==5.1.0
9
+ tiktoken
10
+ rank_bm25
11
+ fastembed
12
+ flashrank
src/{retriever/__init__.py → __init__.py} RENAMED
File without changes
src/indexing/{document_processor.py → document_processing.py} RENAMED
File without changes
src/indexing/vectore_store.py CHANGED
@@ -1,23 +1,27 @@
1
  from langchain_huggingface import HuggingFaceEmbeddings
 
2
  from langchain_chroma import Chroma
3
 
4
  class VectorStoreManager:
5
  def __init__(self, embedding_model="intfloat/multilingual-e5-small"):
6
  self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
7
 
8
- def create_vector_store(self, collection_name="my_collection", persist_directory=None):
9
  """Create a new vector store"""
10
- store_params = {
11
- "collection_name": collection_name,
12
- "embedding_function": self.embeddings,
13
- }
14
- if persist_directory:
15
- store_params["persist_directory"] = persist_directory
16
-
17
- return Chroma(**store_params)
 
 
 
18
 
19
- def index_documents(self, documents, collection_name="my_collection", persist_directory=None):
20
  """Index documents into vector store"""
21
- vector_store = self.create_vector_store(collection_name, persist_directory)
22
  vector_store.add_documents(documents=documents)
23
- return vector_store
 
1
  from langchain_huggingface import HuggingFaceEmbeddings
2
+ from langchain_community.vectorstores import SKLearnVectorStore
3
  from langchain_chroma import Chroma
4
 
5
  class VectorStoreManager:
6
  def __init__(self, embedding_model="intfloat/multilingual-e5-small"):
7
  self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
8
 
9
+ def create_vector_store(self, collection_name, presist_directory):
10
  """Create a new vector store"""
11
+ # vector_store = SKLearnVectorStore.from_documents(
12
+ # documents=documents,
13
+ # embedding=self.embeddings,
14
+ # )
15
+ vector_store = Chroma(
16
+ collection_name=collection_name,
17
+ embedding_function=self.embeddings,
18
+ persist_directory=presist_directory, # Where to save data locally, remove if not necessary
19
+ )
20
+
21
+ return vector_store
22
 
23
+ def index_documents(self, documents, collection_name, presist_directory):
24
  """Index documents into vector store"""
25
+ vector_store = self.create_vector_store(collection_name, presist_directory)
26
  vector_store.add_documents(documents=documents)
27
+ return vector_store
src/llm/__init__.py ADDED
File without changes
src/llm/llm_interface.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_groq import ChatGroq
3
+
4
+ llm_groq = ChatGroq(
5
+ model="llama3-8b-8192",
6
+ temperature=0.1,
7
+ api_key=os.getenv("GROQ_API_KEY"),
8
+ max_retries=3,
9
+ streaming=True,
10
+ )
src/state.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Sequence
2
+ from typing_extensions import TypedDict
3
+ from langchain_core.messages import BaseMessage
4
+ from langgraph.graph.message import add_messages
5
+
6
+ class AgentState(TypedDict):
7
+ messages: Annotated[Sequence[BaseMessage], add_messages]
src/tools_retrieval/__init__.py ADDED
File without changes
src/{retriever → tools_retrieval}/retriever.py RENAMED
@@ -1,7 +1,10 @@
1
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
 
2
  from langchain.retrievers import ContextualCompressionRetriever
3
  from langchain.retrievers.document_compressors import FlashrankRerank
4
 
 
 
5
 
6
  class RetrieverManager:
7
  def __init__(self, vector_store):
@@ -31,4 +34,13 @@ class RetrieverManager:
31
  return ContextualCompressionRetriever(
32
  base_compressor=compressor,
33
  base_retriever=base_retriever
 
 
 
 
 
 
 
 
 
34
  )
 
1
+ from langchain_community.retrievers import BM25Retriever
2
+ from langchain.retrievers import EnsembleRetriever
3
  from langchain.retrievers import ContextualCompressionRetriever
4
  from langchain.retrievers.document_compressors import FlashrankRerank
5
 
6
+ from langchain.tools.retriever import create_retriever_tool
7
+
8
 
9
  class RetrieverManager:
10
  def __init__(self, vector_store):
 
34
  return ContextualCompressionRetriever(
35
  base_compressor=compressor,
36
  base_retriever=base_retriever
37
+ )
38
+
39
+ def create_retriever(self, documents):
40
+ base_retriever = self.create_ensemble_retriever(documents)
41
+ compression_retriever = self.create_compression_retriever(base_retriever=base_retriever)
42
+ return create_retriever_tool(
43
+ compression_retriever,
44
+ "retrieve_docs",
45
+ "use tools for search through the user's provided documents and return relevant information about user query.",
46
  )
src/workflow.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import hub
2
+ from langchain_core.output_parsers import StrOutputParser
3
+ from langchain_core.messages import HumanMessage
4
+ from typing import Literal
5
+ from pydantic import BaseModel, Field
6
+ from langchain_core.prompts import PromptTemplate
7
+
8
+ from langgraph.graph import END, StateGraph, START
9
+ from langgraph.prebuilt import ToolNode, tools_condition
10
+ from .state import AgentState
11
+ from src.llm.llm_interface import llm_groq
12
+
13
+
14
+ class GradeDocs(BaseModel):
15
+ binary_score: str = Field(description="Relevance score 'yes' or 'no'")
16
+
17
+
18
+
19
+ class RAGWorkflow:
20
+ def __init__(self, retriever_tool):
21
+ self.workflow = StateGraph(AgentState)
22
+ self.tools = [retriever_tool]
23
+ self.retrieve = ToolNode([retriever_tool])
24
+ self._setup_nodes()
25
+ self._setup_edges()
26
+
27
+ def _setup_nodes(self):
28
+ self.workflow.add_node("agent", self._agent_node)
29
+ self.workflow.add_node("retrieve", self.retrieve)
30
+ self.workflow.add_node("generate", self._generator_node)
31
+
32
+ self.workflow.add_node("rewrite", self._rewrite_node)
33
+
34
+ def _setup_edges(self):
35
+ self.workflow.add_edge(START, "agent")
36
+ self.workflow.add_conditional_edges(
37
+ "agent",
38
+ tools_condition,
39
+ {
40
+ "tools": "retrieve",
41
+ END: END
42
+ }
43
+ )
44
+
45
+ self.workflow.add_conditional_edges(
46
+ "retrieve",
47
+ self._grade_docs,
48
+ )
49
+ self.workflow.add_edge("generate", END)
50
+ self.workflow.add_edge("rewrite", "agent")
51
+
52
+ def compile(self):
53
+ return self.workflow.compile()
54
+
55
+
56
+ def _agent_node(self, state):
57
+ print("---CALL AGENT---")
58
+ messages = state["messages"]
59
+
60
+ model = llm_groq.bind_tools(self.tools)
61
+ response = model.invoke(messages[0].content)
62
+ return {"messages": [response]}
63
+
64
+
65
+ def _generator_node(self, state):
66
+ print("---GENERATE---")
67
+ messages = state["messages"]
68
+ question = messages[0].content
69
+ docs = messages[-1].content
70
+
71
+ prompt = hub.pull("rlm/rag-prompt")
72
+ rag_chain = prompt | llm_groq | StrOutputParser()
73
+
74
+ response = rag_chain.invoke({"context": docs, "question": question})
75
+ return {"messages": [response]}
76
+
77
+
78
+
79
+ def _rewrite_node(self, state):
80
+ print("---REWRITE---")
81
+ messages = state["messages"]
82
+ question = messages[0].content
83
+
84
+ msg = [
85
+ HumanMessage(
86
+ content=f""" \n
87
+ Look at the input and try to reason about the underlying semantic intent / meaning. \n
88
+ Here is the initial question:
89
+ \n ------- \n
90
+ {question}
91
+ \n ------- \n
92
+ Formulate an improved question: """,
93
+ )
94
+ ]
95
+
96
+ response = llm_groq.invoke(msg)
97
+ return {"messages": [response]}
98
+
99
+
100
+
101
+ def _grade_docs(self, state):
102
+ print("---CHECK RELEVANCE---")
103
+
104
+ llm_with_tool = llm_groq.with_structured_output(GradeDocs)
105
+
106
+ prompt = PromptTemplate(
107
+ template="""You are a grader assessing relevance of a retrieved document to a user question. \n
108
+ Here is the retrieved document: \n\n {context} \n\n
109
+ Here is the user question: {question} \n
110
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
111
+ Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
112
+ input_variables=["context", "question"],
113
+ )
114
+
115
+ chain = prompt | llm_with_tool
116
+
117
+ messages = state["messages"]
118
+ question = messages[0].content
119
+ docs = messages[-1].content
120
+
121
+ scored_result = chain.invoke({"question": question, "context": docs})
122
+
123
+ if scored_result.binary_score == "yes":
124
+ print("---DECISION: DOCS RELEVANT---")
125
+ return "generate"
126
+ print("---DECISION: DOCS NOT RELEVANT---")
127
+ return "rewrite"
128
+