vamsidharmuthireddy committed on
Commit
3cfb95f
·
verified ·
1 Parent(s): 4c85569

Upload 7 files

Browse files
Files changed (7) hide show
  1. Dockerfile +6 -4
  2. app.py +36 -233
  3. graph.py +182 -0
  4. load_vector_db.py +105 -0
  5. logging_config.py +47 -0
  6. requirements.txt +1 -0
  7. utils.py +25 -0
Dockerfile CHANGED
@@ -5,8 +5,6 @@ ENV PYTHONDONTWRITEBYTECODE=1
5
  ENV PYTHONUNBUFFERED=1
6
  ENV STREAMLIT_HOME=/app/.streamlit
7
 
8
- ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
9
-
10
  RUN useradd -m -u 1000 appuser
11
 
12
  # Set working directory
@@ -31,9 +29,13 @@ COPY requirements.txt .
31
  RUN pip install --no-cache-dir --upgrade pip && pip install -r requirements.txt
32
 
33
  # Copy app files
34
- COPY app.py .
35
  COPY docs .
36
- COPY vectordb_milvus.db .
 
 
 
 
 
37
 
38
  # Create required directories and fix permissions
39
  RUN mkdir -p $STREAMLIT_HOME && \
 
5
  ENV PYTHONUNBUFFERED=1
6
  ENV STREAMLIT_HOME=/app/.streamlit
7
 
 
 
8
  RUN useradd -m -u 1000 appuser
9
 
10
  # Set working directory
 
29
  RUN pip install --no-cache-dir --upgrade pip && pip install -r requirements.txt
30
 
31
# Copy app files
# NOTE: `COPY <dir> .` copies the directory's *contents* into WORKDIR, so
# the relative paths the Python code opens — "docs/" (load_vector_db.py)
# and "db/vectordb_milvus.db" (Milvus Lite URI) — would not exist inside
# the image. Copy into named subdirectories to preserve the layout.
COPY docs ./docs
COPY db ./db
COPY logging_config.py .
COPY load_vector_db.py .
COPY app.py .
COPY graph.py .
COPY utils.py .
39
 
40
  # Create required directories and fix permissions
41
  RUN mkdir -p $STREAMLIT_HOME && \
app.py CHANGED
@@ -16,16 +16,25 @@ from langchain_core.tools import tool
16
  from langchain_core.messages import SystemMessage
17
  from langgraph.prebuilt import ToolNode, tools_condition
18
  from langchain_milvus import Milvus
 
 
 
 
 
 
 
19
 
20
 
21
  # Load environment variables
22
- load_dotenv()
23
 
24
  # Set AWS credentials from environment variables
25
  os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get("aws_access_key_id")
26
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get("aws_secret_access_key")
27
  os.environ["AWS_SESSION_TOKEN"] = os.environ.get("aws_session_token")
28
  os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION")
 
 
29
 
30
  # Initialize session state variables if they don't exist
31
  if "messages" not in st.session_state:
@@ -33,235 +42,7 @@ if "messages" not in st.session_state:
33
  if "initialized" not in st.session_state:
34
  st.session_state.initialized = False
35
 
36
- def init_vector_db(embeddings):
37
- # Initialize vector store
38
- URI = "./vectordb_milvus.db"
39
- collection_name = "my_collection"
40
-
41
-
42
- # Check if the collection already exists
43
- try:
44
- # if os.path.exists(URI):
45
- # st.info("Found existing Milvus db.")
46
-
47
- # First try to connect to existing collection
48
- st.info("Checking for existing Milvus db...")
49
- vector_store = Milvus(
50
- embedding_function=embeddings,
51
- connection_args={"uri": URI},
52
- auto_id=True,
53
- collection_name=collection_name,
54
- index_params={"index_type": "FLAT", "metric_type": "COSINE"},
55
- )
56
-
57
- results = vector_store.similarity_search("test query", k=1)
58
-
59
- if len(results) > 0:
60
- st.success("Document data found in existing collection.")
61
- documents_loaded = True
62
- else:
63
- st.info("Collection exists but might be empty. Will check for documents.")
64
- documents_loaded = False
65
-
66
- except Exception as e:
67
- st.info("Creating new Milvus collection...")
68
- vector_store = Milvus(
69
- embedding_function=embeddings,
70
- connection_args={"uri": URI},
71
- auto_id=True,
72
- collection_name=collection_name,
73
- index_params={"index_type": "FLAT", "metric_type": "COSINE"},
74
- )
75
- documents_loaded = False
76
-
77
- # Load documents if needed
78
- if not documents_loaded:
79
- folder_path = "docs"
80
- loader = DirectoryLoader(
81
- folder_path,
82
- glob="**/*.pdf",
83
- loader_cls=PyPDFLoader
84
- )
85
-
86
- try:
87
- documents = loader.load()
88
- st.info(f"Loaded {len(documents)} PDF pages.")
89
-
90
- if len(documents) > 0:
91
- # Split documents
92
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
93
- all_splits = text_splitter.split_documents(documents)
94
- st.info(f"Total Document splits: {len(all_splits)}")
95
-
96
- # Add documents to vector store
97
- _ = vector_store.add_documents(documents=all_splits)
98
- st.success("Documents added to vector store.")
99
- else:
100
- st.warning("No PDF documents found in the 'docs' folder.")
101
- except Exception as e:
102
- st.error(f"Error loading documents: {e}")
103
-
104
- return vector_store
105
-
106
-
107
-
108
- def init_app():
109
- """Initialize the app components and return them."""
110
- with st.spinner("Initializing PDF chat application..."):
111
- # Initialize LLM
112
- llm = init_chat_model(
113
- "anthropic.claude-3-5-sonnet-20240620-v1:0",
114
- model_provider="bedrock_converse",
115
- temperature=0
116
- )
117
 
118
- # Initialize embeddings
119
- embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
120
-
121
- vector_store = init_vector_db(embeddings)
122
-
123
- class State(MessagesState):
124
- context: List[Document]
125
-
126
- # Create a retrieval tool that captures the vector_store
127
- @tool(response_format="content_and_artifact")
128
- def retrieve_tool(query: str):
129
- """Retrieve information related to a query."""
130
- retrieved_docs = vector_store.similarity_search(query, k=5)
131
- serialized = "\n\n".join(
132
- (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
133
- for doc in retrieved_docs
134
- )
135
- print(f"retrieved_docs : {retrieved_docs}")
136
- return serialized, retrieved_docs
137
-
138
- # Create the LLM tool-calling function with direct reference to llm
139
- def query_or_respond_fn(state: State):
140
- """Generate tool call for retrieval or respond."""
141
- print(f"state['messages'] : {state["messages"]}")
142
- valid_messages = [
143
- msg for msg in state["messages"]
144
- if msg.content
145
- ]
146
-
147
- if not valid_messages:
148
- return {"messages": []}
149
- llm_with_tools = llm.bind_tools([retrieve_tool])
150
- response = llm_with_tools.invoke(state["messages"])
151
- # MessagesState appends messages to state instead of overwriting
152
- return {"messages": [response]}
153
-
154
- # Create the generate function with direct reference to llm
155
- def generate_fn(state: State):
156
- """Generate answer."""
157
- # Get generated ToolMessages
158
- recent_tool_messages = []
159
- for message in reversed(state["messages"]):
160
- if message.type == "tool":
161
- recent_tool_messages.append(message)
162
- else:
163
- break
164
- tool_messages = recent_tool_messages[::-1]
165
-
166
- # Format into prompt
167
- sources_text = ""
168
- # print(f"tool_messages { tool_messages}")
169
- print(f"tool_messages { len(tool_messages)}")
170
-
171
- tool_messages_latest = tool_messages[0]
172
- for artifact in tool_messages_latest.artifact:
173
- # artifact = i.artifact
174
- page_label = artifact.metadata.get('page_label')
175
- page = artifact.metadata.get('page')
176
- source = artifact.metadata.get('source')
177
-
178
- sources_text += f"Source: {source}, Page: {page}, Page Label: {page_label}\n"
179
-
180
- # print(source, page, page_label)
181
- print(f"sources_text { sources_text}")
182
-
183
- docs_content = "\n\n".join(doc.content for doc in tool_messages)
184
- system_message_content = (
185
- "You are an assistant for question-answering tasks."
186
- "Use the following pieces of retrieved context to answer the question."
187
- "This is your only source of knowledge."
188
- "If you don't know the answer, say that you don't know and STOP - do not provide related information."
189
- "You are not allowed to make up answers."
190
- "You are not allowed to use any external knowledge."
191
- "You are not allowed to make assumptions."
192
- "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content."
193
- "Keep your answers strictly limited to information that directly answers the user's specific question."
194
- "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics."
195
- "If the query is not clear, ask for clarification."
196
- "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."
197
- "Keep your answers accurate and concise to the source content."
198
- "\n\n"
199
- f"{docs_content}"
200
-
201
- # "Provide the source of the answer like this format at the end of the response: 'Page: Page Number, Source: Source Name' There could be multiple sources, so adjust the response accordingly. Each new source reference should be listed on a new line following this format very strictly. "
202
- # """####Example: This format is **strictly required**. Do not combine multiple sources in the same line. No of lines and sources acn be dynamic.
203
- # Page: 1, Source: Source 1
204
- # Page: 2, Source: Source 2
205
- # Page: 3, Source: Source 3
206
- # """
207
- # f"Paste this content as is {sources_text}"
208
-
209
- )
210
- conversation_messages = [
211
- message
212
- for message in state["messages"]
213
- if message.type in ("human", "system")
214
- or (message.type == "ai" and not message.tool_calls)
215
- ]
216
- prompt = [SystemMessage(system_message_content)] + conversation_messages
217
-
218
- # Run
219
- response = llm.invoke(prompt)
220
- # return {"messages": [response]}
221
- context = []
222
- for tool_message in tool_messages:
223
- context.extend(tool_message.artifact)
224
- return {"messages": [response], "context": context}
225
-
226
- # Execute the retrieval
227
- tools_node = ToolNode([retrieve_tool])
228
-
229
- # Build the graph
230
- graph_builder = StateGraph(MessagesState)
231
- graph_builder.add_node("query_or_respond", query_or_respond_fn)
232
- graph_builder.add_node("tools", tools_node)
233
- graph_builder.add_node("generate", generate_fn)
234
- graph_builder.set_entry_point("query_or_respond")
235
- graph_builder.add_conditional_edges(
236
- "query_or_respond",
237
- tools_condition,
238
- {END: END, "tools": "tools"},
239
- )
240
- graph_builder.add_edge("tools", "generate")
241
- graph_builder.add_edge("generate", END)
242
- graph = graph_builder.compile()
243
-
244
- st.success("Initialization complete!")
245
- return {"graph": graph}
246
-
247
- def extract_text_from_content(content):
248
- """Extract text from various message content formats."""
249
- if isinstance(content, str):
250
- return content
251
- elif isinstance(content, list):
252
- # Handle list of text items or dictionaries
253
- text_parts = []
254
- for item in content:
255
- if isinstance(item, dict):
256
- # Extract text from dictionary format
257
- if 'text' in item:
258
- text_parts.append(item['text'])
259
- elif isinstance(item, str):
260
- text_parts.append(item)
261
- return ''.join(text_parts)
262
- else:
263
- # Fallback for any other format
264
- return str(content)
265
 
266
  def run_graph(graph, input_message: str):
267
  """Run the graph with the input message."""
@@ -286,6 +67,8 @@ def run_graph(graph, input_message: str):
286
  response_chunks = []
287
  values = []
288
 
 
 
289
  for mode, mode_chunk in graph.stream(
290
  input_message_formatted,
291
  stream_mode=["messages", "values"],
@@ -295,8 +78,24 @@ def run_graph(graph, input_message: str):
295
  elif mode == "messages":
296
  message, metadata = mode_chunk
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  if metadata["langgraph_node"] == "generate":
299
  if hasattr(message, 'content'):
 
 
 
300
  content = message.content
301
  # Extract text depending on content format
302
  chunk_text = extract_text_from_content(content)
@@ -306,6 +105,9 @@ def run_graph(graph, input_message: str):
306
  yield chunk_text, values
307
  full_response = ''.join(response_chunks)
308
 
 
 
 
309
  # print(f"Full text: {full_response}")
310
  # print(f"full values: {values}")
311
  st.conversation_history.append({
@@ -320,7 +122,7 @@ st.title("PDF Question-Answering Chat")
320
  # Initialize the app if not already done
321
  if not st.session_state.initialized:
322
  try:
323
- app_components = init_app()
324
  st.session_state.app_components = app_components
325
  st.session_state.initialized = True
326
  st.conversation_history = []
@@ -353,13 +155,14 @@ if prompt := st.chat_input("Ask a question about your PDFs"):
353
  values = {}
354
  for chunk, values in run_graph(st.session_state.app_components["graph"], prompt):
355
  if chunk: # Only process non-empty chunks
 
356
  full_response += chunk
357
  message_placeholder.markdown(full_response + "▌")
358
 
359
  values = values[-1]
360
  # print(f"values: {values}")
361
- print(f"values keys: {values.keys()}")
362
- print(f"'context' in values: { 'context' in values }")
363
  if 'context' in values:
364
  pages_dict = {}
365
 
@@ -388,7 +191,7 @@ if prompt := st.chat_input("Ask a question about your PDFs"):
388
 
389
 
390
  message_placeholder.markdown(full_response)
391
- print(f"Full response: {full_response}")
392
 
393
  # Add assistant response to chat history
394
  st.session_state.messages.append({"role": "assistant", "content": full_response})
 
16
  from langchain_core.messages import SystemMessage
17
  from langgraph.prebuilt import ToolNode, tools_condition
18
  from langchain_milvus import Milvus
19
+ from utils import extract_text_from_content
20
+ from logging_config import setup_logger
21
+ from load_vector_db import init_vector_db
22
+ from graph import init_graph
23
+ import time
24
+
25
+ logger = setup_logger(__name__)
26
 
27
 
28
# Load environment variables (.env values take precedence over the shell).
load_dotenv(override=True)

# Map the lowercase .env credential names onto the canonical AWS variables
# read by the Bedrock clients. os.environ rejects None values, so fail with
# a clear message when a source variable is missing instead of raising an
# opaque TypeError at assignment time.
for _target, _source in [
    ("AWS_ACCESS_KEY_ID", "aws_access_key_id"),
    ("AWS_SECRET_ACCESS_KEY", "aws_secret_access_key"),
    ("AWS_SESSION_TOKEN", "aws_session_token"),
    ("AWS_DEFAULT_REGION", "AWS_DEFAULT_REGION"),
]:
    _value = os.environ.get(_source)
    if _value is None:
        raise RuntimeError(f"Missing required environment variable: {_source}")
    os.environ[_target] = _value

# SECURITY: the previous `print(os.environ["AWS_ACCESS_KEY_ID"])` wrote a
# live credential to stdout (and any captured logs); never print secrets.
38
 
39
  # Initialize session state variables if they don't exist
40
  if "messages" not in st.session_state:
 
42
  if "initialized" not in st.session_state:
43
  st.session_state.initialized = False
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def run_graph(graph, input_message: str):
48
  """Run the graph with the input message."""
 
67
  response_chunks = []
68
  values = []
69
 
70
+ start = time.time()
71
+ time_to_start_streaming = None
72
  for mode, mode_chunk in graph.stream(
73
  input_message_formatted,
74
  stream_mode=["messages", "values"],
 
78
  elif mode == "messages":
79
  message, metadata = mode_chunk
80
 
81
+ # if metadata["langgraph_node"] == "query_or_respond":
82
+ # logger.info(f"message.tool_calls: {message.tool_calls}")
83
+ # if not message.tool_calls:
84
+ # content = message.content
85
+ # logger.info(f"query_or_respond content type: {isinstance(content, str)}")
86
+ # logger.info(f"query_or_respond content: {content}")
87
+ # if isinstance(content, str):
88
+ # chunk_text = content
89
+ # # chunk_text = extract_text_from_content(content)
90
+ # if chunk_text:
91
+ # response_chunks.append(chunk_text)
92
+ # yield chunk_text, values
93
+
94
  if metadata["langgraph_node"] == "generate":
95
  if hasattr(message, 'content'):
96
+ if time_to_start_streaming is None:
97
+ time_to_start_streaming = time.time() - start
98
+ logger.info(f"Time taken to start streaming: {time_to_start_streaming} seconds")
99
  content = message.content
100
  # Extract text depending on content format
101
  chunk_text = extract_text_from_content(content)
 
105
  yield chunk_text, values
106
  full_response = ''.join(response_chunks)
107
 
108
+ logger.info(f"Time taken for complete generation: {time.time() - start} seconds")
109
+
110
+
111
  # print(f"Full text: {full_response}")
112
  # print(f"full values: {values}")
113
  st.conversation_history.append({
 
122
  # Initialize the app if not already done
123
  if not st.session_state.initialized:
124
  try:
125
+ app_components = init_graph()
126
  st.session_state.app_components = app_components
127
  st.session_state.initialized = True
128
  st.conversation_history = []
 
155
  values = {}
156
  for chunk, values in run_graph(st.session_state.app_components["graph"], prompt):
157
  if chunk: # Only process non-empty chunks
158
+ # print(f"Chunk: {chunk}")
159
  full_response += chunk
160
  message_placeholder.markdown(full_response + "▌")
161
 
162
  values = values[-1]
163
  # print(f"values: {values}")
164
+ logger.info(f"values keys: {values.keys()}")
165
+ logger.info(f"'context' in values: { 'context' in values }")
166
  if 'context' in values:
167
  pages_dict = {}
168
 
 
191
 
192
 
193
  message_placeholder.markdown(full_response)
194
+ logger.info(f"Full response: {full_response}")
195
 
196
  # Add assistant response to chat history
197
  st.session_state.messages.append({"role": "assistant", "content": full_response})
graph.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dotenv import load_dotenv
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing_extensions import List, Dict

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from utils import extract_text_from_content
# FIX: setup_logger was imported twice (duplicate line removed).
from logging_config import setup_logger
from load_vector_db import init_vector_db
import time

# Module-level logger shared by the graph node functions.
logger = setup_logger(__name__)
26
+
27
+
28
def init_graph():
    """Build and return the LangGraph RAG pipeline.

    Initializes the Bedrock chat model and embeddings, opens the Milvus
    vector store plus the reranking retriever, then wires up the
    query_or_respond -> tools -> generate graph.

    Returns:
        dict: {"graph": compiled LangGraph}, consumed by app.py.
    """
    with st.spinner("Initializing PDF chat application..."):
        # Initialize LLM (temperature 0 for deterministic QA answers).
        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
            model_provider="bedrock_converse",
            temperature=0
        )

        # Initialize embeddings
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

        # vector_store is kept for potential direct similarity_search use;
        # retrieval itself goes through the reranking compression_retriever.
        vector_store, compression_retriever = init_vector_db(embeddings)

        # Graph state: conversation messages plus the retrieved documents
        # ("context"), which the UI uses for source attribution.
        class State(MessagesState):
            context: List[Document]

        # Create a retrieval tool that closes over compression_retriever.
        @tool(response_format="content_and_artifact")
        def retrieve_tool(query: str):
            """Retrieve information related to a query."""
            start = time.time()
            # retrieved_docs = vector_store.similarity_search(query, k=50)
            # NOTE(review): the compressor already caps results via top_n=10;
            # the extra k=10 kwarg here is presumably ignored — confirm
            # against the ContextualCompressionRetriever API.
            retrieved_docs = compression_retriever.invoke(input=query, k=10)
            serialized = "\n\n".join(
                (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
                for doc in retrieved_docs
            )
            end = time.time()
            logger.info(f"Time taken for vectordb retrieval: {end - start} seconds")
            logger.info(f"retrieved_docs num: {len(retrieved_docs)}")
            logger.info(f"retrieved_docs : {retrieved_docs}")
            return serialized, retrieved_docs

        # Create the LLM tool-calling function with direct reference to llm.
        def query_or_respond_fn(state: State):
            """Generate tool call for retrieval or respond directly."""
            start = time.time()
            logger.info(f"state['messages'] : {state['messages']}")
            # Skip entirely-empty messages (e.g. stripped tool stubs).
            valid_messages = [
                msg for msg in state["messages"]
                if msg.content
            ]

            if not valid_messages:
                return {"messages": []}
            llm_with_tools = llm.bind_tools([retrieve_tool])
            response = llm_with_tools.invoke(state["messages"])
            end = time.time()
            logger.info(f"Time taken for query_or_respond_fn LLM invocation: {end - start} seconds")
            # MessagesState appends messages to state instead of overwriting
            return {"messages": [response]}

        # Create the generate function with direct reference to llm.
        def generate_fn(state: State):
            """Generate the final answer from the retrieved context."""
            start = time.time()
            # Collect the ToolMessages produced by the most recent tool run
            # (walk backwards until a non-tool message is hit).
            recent_tool_messages = []
            for message in reversed(state["messages"]):
                if message.type == "tool":
                    recent_tool_messages.append(message)
                else:
                    break
            tool_messages = recent_tool_messages[::-1]
            logger.info(f"tool_messages {tool_messages}")

            # Build a human-readable source list from the latest retrieval
            # (logged only; not injected into the prompt).
            sources_text = ""
            tool_messages_latest = tool_messages[0]
            for artifact in tool_messages_latest.artifact:
                page_label = artifact.metadata.get('page_label')
                page = artifact.metadata.get('page')
                source = artifact.metadata.get('source')
                sources_text += f"Source: {source}, Page: {page}, Page Label: {page_label}\n"
            logger.info(f"sources_text {sources_text}")

            docs_content = "\n\n".join(doc.content for doc in tool_messages)
            system_message_content = (
                "You are an assistant for question-answering tasks."
                "Use the following pieces of retrieved context to answer the question."
                "This is your only source of knowledge."
                "If you don't know the answer, say that you don't know and STOP - do not provide related information."
                "You are not allowed to make up answers."
                "You are not allowed to use any external knowledge."
                "You are not allowed to make assumptions."
                "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content."
                "Keep your answers strictly limited to information that directly answers the user's specific question."
                "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics."
                "If the query is single word or phrase, ask the user to provide a complete question."
                "If the query is not clear, ask for clarification."
                "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions."
                "If the query contains multiple questions, answer only the first question and ask the user to ask the next question."
                "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately."
                "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."
                "Keep your answers accurate and concise to the source content."
                "\n\n"
                f"{docs_content}"

            )
            # Keep only human/system turns and final (non-tool-calling) AI turns.
            conversation_messages = [
                message
                for message in state["messages"]
                if message.type in ("human", "system")
                or (message.type == "ai" and not message.tool_calls)
            ]
            prompt = [SystemMessage(system_message_content)] + conversation_messages

            # Run the answer generation.
            start_llm = time.time()
            response = llm.invoke(prompt)
            # Surface every retrieved document in the "context" channel.
            context = []
            for tool_message in tool_messages:
                context.extend(tool_message.artifact)

            end = time.time()
            logger.info(f"Time taken for generate_fn : {end - start} seconds")
            logger.info(f"Time taken for generate_fn LLM invocation: {end - start_llm} seconds")

            return {"messages": [response], "context": context}

        # Execute the retrieval
        tools_node = ToolNode([retrieve_tool])

        # Build the graph.
        # FIX: use State (which declares the "context" channel) as the graph
        # schema — the previous StateGraph(MessagesState) had no "context"
        # key, so generate_fn's {"context": ...} update targeted an
        # undeclared channel.
        graph_builder = StateGraph(State)
        graph_builder.add_node("query_or_respond", query_or_respond_fn)
        graph_builder.add_node("tools", tools_node)
        graph_builder.add_node("generate", generate_fn)
        graph_builder.set_entry_point("query_or_respond")
        graph_builder.add_conditional_edges(
            "query_or_respond",
            tools_condition,
            {END: END, "tools": "tools"},
        )
        graph_builder.add_edge("tools", "generate")
        graph_builder.add_edge("generate", END)
        graph = graph_builder.compile()

        st.success("Initialization complete!")
        return {"graph": graph}
load_vector_db.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain_aws import BedrockEmbeddings
4
+
5
+ from langchain.chat_models import init_chat_model
6
+ from langchain_core.documents import Document
7
+ from typing_extensions import List, Dict, TypedDict
8
+
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from langgraph.graph import START, StateGraph, END
11
+
12
+ from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
13
+ from langgraph.graph import MessagesState
14
+ from langchain_core.tools import tool
15
+ from langchain_core.messages import SystemMessage
16
+ from langgraph.prebuilt import ToolNode, tools_condition
17
+ from langchain_milvus import Milvus
18
+ from langchain_openai import ChatOpenAI
19
+ from pydantic import BaseModel, Field
20
+ from logging_config import setup_logger
21
+ from flashrank import Ranker
22
+ from langchain_community.document_compressors import FlashrankRerank
23
+ from langchain.retrievers import ContextualCompressionRetriever
24
+
25
+ logger = setup_logger(__name__)
26
+
27
+
28
def init_vector_db(embeddings):
    """Open (or build) the Milvus vector store and a reranking retriever.

    Connects to the local Milvus Lite database, ingesting the PDFs under
    ``docs/`` when the collection is missing or empty, then wraps the
    store's retriever in a FlashRank-based contextual compression retriever.

    Args:
        embeddings: Embedding model used by the vector store.

    Returns:
        tuple: (vector_store, compression_retriever)
    """
    # Initialize vector store
    URI = "db/vectordb_milvus.db"
    collection_name = "my_collection"

    # Milvus Lite needs the DB file's parent directory to exist.
    os.makedirs(os.path.dirname(URI), exist_ok=True)

    # Check if the collection already exists
    try:
        st.info("Checking for existing Milvus db...")
        vector_store = Milvus(
            embedding_function=embeddings,
            connection_args={"uri": URI},
            auto_id=True,
            collection_name=collection_name,
            index_params={"index_type": "FLAT", "metric_type": "COSINE"},
        )

        # A probe query tells us whether the collection holds any data.
        results = vector_store.similarity_search("test query", k=1)

        if len(results) > 0:
            st.success("Document data found in existing collection.")
            documents_loaded = True
        else:
            st.info("Collection exists but might be empty. Will check for documents.")
            documents_loaded = False

    except Exception as e:
        # FIX: the failure was previously swallowed silently; record it so a
        # broken DB file is distinguishable from a simple first run.
        logger.warning(f"Could not open existing collection ({e}); creating a new one.")
        st.info("Creating new Milvus collection...")
        vector_store = Milvus(
            embedding_function=embeddings,
            connection_args={"uri": URI},
            auto_id=True,
            collection_name=collection_name,
            index_params={"index_type": "FLAT", "metric_type": "COSINE"},
        )
        documents_loaded = False

    # Load documents if needed
    if not documents_loaded:
        folder_path = "docs"
        loader = DirectoryLoader(
            folder_path,
            glob="**/*.pdf",
            loader_cls=PyPDFLoader
        )

        try:
            documents = loader.load()
            st.info(f"Loaded {len(documents)} PDF pages.")

            if len(documents) > 0:
                # Split documents into overlapping chunks for retrieval.
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
                all_splits = text_splitter.split_documents(documents)
                st.info(f"Total Document splits: {len(all_splits)}")

                # Add documents to vector store
                _ = vector_store.add_documents(documents=all_splits)
                st.success("Documents added to vector store.")
            else:
                st.warning("No PDF documents found in the 'docs' folder.")
        except Exception as e:
            # FIX: log the full traceback instead of only flashing it in the UI.
            logger.exception("Error loading documents")
            st.error(f"Error loading documents: {e}")

    # Over-fetch 50 candidates, then let the reranker keep the 10 best.
    retriever = vector_store.as_retriever(search_kwargs={"k": 50})

    ranker_client = Ranker(model_name="ms-marco-MultiBERT-L-12",
                           cache_dir="./models")

    compressor = FlashrankRerank(client=ranker_client, top_n=10)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    return vector_store, compression_retriever
logging_config.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# logging_config.py
import logging
import os
from datetime import datetime

LOGS_DIR = "logs"

# Make sure the log directory exists before any handler opens a file.
os.makedirs(LOGS_DIR, exist_ok=True)

# One log file per day, named by date.
log_filename = os.path.join(LOGS_DIR, f"app_{datetime.now().strftime('%Y%m%d')}.log")


def setup_logger(name):
    """Return a logger that writes DEBUG+ to the daily file and INFO+ to the console.

    Calling this again with the same name reuses the already-configured
    logger without attaching duplicate handlers.
    """
    logger = logging.getLogger(name)

    if logger.handlers:
        # Already configured on a previous call — hand it back as-is.
        return logger

    logger.setLevel(logging.DEBUG)

    # File handler: verbose, with source location in every record.
    to_file = logging.FileHandler(log_filename)
    to_file.setLevel(logging.DEBUG)
    to_file.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    ))

    # Console handler: terser format, less verbose level.
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(logging.Formatter(
        '%(levelname)s - %(name)s - %(message)s'
    ))

    logger.addHandler(to_file)
    logger.addHandler(to_console)

    return logger
requirements.txt CHANGED
@@ -31,6 +31,7 @@ durationpy==0.9
31
  executing==2.2.0
32
  fastapi==0.115.12
33
  filelock==3.18.0
 
34
  flatbuffers==25.2.10
35
  frozenlist==1.6.0
36
  fsspec==2025.3.2
 
31
  executing==2.2.0
32
  fastapi==0.115.12
33
  filelock==3.18.0
34
+ FlashRank==0.2.10
35
  flatbuffers==25.2.10
36
  frozenlist==1.6.0
37
  fsspec==2025.3.2
utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging_config import setup_logger
2
+
3
+
4
+ logger = setup_logger(__name__)
5
+
6
+
7
+
8
def extract_text_from_content(content):
    """Normalize a message ``content`` payload to a plain string.

    Handles the shapes message content can take here: a plain string is
    returned unchanged; a list is flattened by concatenating its string
    items and the ``'text'`` values of its dict items (other items are
    skipped); anything else is stringified as a fallback.
    """
    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        # Fallback for any other format.
        return str(content)

    # Handle a list of text items and/or dictionaries.
    pieces = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and 'text' in part:
            pieces.append(part['text'])
    return ''.join(pieces)