File size: 20,509 Bytes
b3f819d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
from dotenv import load_dotenv
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing import List, Dict, Any

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, AIMessage, ToolMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from utils import extract_text_from_content
from logging_config import setup_logger
from load_vector_db import init_vector_db
from logging_config import setup_logger
import time
from pydantic import BaseModel, Field

logger = setup_logger(__name__)


def init_graph() -> Dict[str, Any]:
    """Build the LangGraph workflow that powers the PDF chat application.

    While a Streamlit spinner is shown, this creates the Bedrock embedding
    model, obtains the vector store and retriever from ``init_vector_db``,
    creates the Bedrock chat model, and then defines and compiles the graph.

    Returns:
        dict: ``{"graph": <compiled graph>}``.
    """
    with st.spinner("Initializing PDF chat application..."):
        # Embedding model handed to init_vector_db.
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

        # NOTE(review): vector_store is not used by the active code in this
        # function; only compression_retriever is queried.
        vector_store, compression_retriever = init_vector_db(embeddings)


        # Chat model shared by every LLM-calling node; temperature is 0.

        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0", 
            model_provider="bedrock_converse",
            temperature=0
        )


        class State(MessagesState):
            # Graph state shared by all nodes, in addition to the inherited
            # ``messages`` list.
            # NOTE(review): MessagesState is presumably a TypedDict, in which
            # case the default values below are not applied to the runtime
            # state; nodes test key presence instead — confirm.

            # Content of the latest human message (written by validate_query).
            user_query: str = ""
            # Validation verdict written by validate_query.
            # NOTE(review): annotated as a dict, but the success path stores a
            # ValidateQuery model instance.
            query_validation: Dict[str, Any] = {}
            # Documents collected from retrieve_tool results
            # (written by extract_retrieved_docs).
            retrieved_docs: List[Document] = []
            # Routing flags written by
            # wait_for_query_validation_and_retrieved_docs and read by
            # final_response_router.
            generate_response: bool = False
            generate_sample_questions: bool = False
            # Text of the suggested questions (written by generate_questions).
            sample_questions : str = ""
            # Final reply (written by the response-generating nodes).
            final_response : str = ""

        class ValidateQuery(BaseModel):
            """Binary score for question validation."""

            # Structured-output schema passed to ``llm.with_structured_output``
            # in validate_query.
            # NOTE(review): the class docstring and the Field descriptions are
            # presumably forwarded to the model as part of the schema, so treat
            # them as prompt text rather than documentation — confirm before
            # rewording them.

            # True when the question is judged clear and answerable.
            valid_question: bool = Field(
                description="Question is clear and answerable, 'true' or 'false'"
            )
            # Model-written justification / improvement advice.
            response: str = Field(
                description="Explanation of the question's validity and suggestions for improvement."
            )


        def validate_query(state: State):
            """Judge whether the latest human message is a clear, answerable question.

            The most recent message whose ``type`` is ``"human"`` is sent to the
            LLM together with the validation instructions, and the structured
            verdict is stored in the graph state.

            Args:
                state: Current graph state; only ``state["messages"]`` is read.

            Returns:
                dict: State update containing ``query_validation`` — always a
                ``ValidateQuery`` instance, because downstream nodes read
                ``.valid_question`` / ``.response`` as attributes — and, when a
                human message exists, ``user_query`` with that message's content.
            """
            start = time.time()

            # Get the latest human message
            human_messages = [msg for msg in state["messages"] if msg.type == "human"]
            if not human_messages:
                return {
                    "query_validation": ValidateQuery(
                        valid_question=False,
                        response="No question found to validate.",
                    )
                }

            latest_question = human_messages[-1].content

            validation_prompt = """
        You are a question validation assistant. Analyze the following question and determine if it is:
        1. Clear and specific (not vague or ambiguous)
        2. Focused on a single topic or closely related topics (not trying to address too many different things)
        3. Answerable without making assumptions
        4. Concise and well-structured

        # "If the query is single word or phrase, ask the user to provide a complete question."
        # "If the query is not clear, ask for clarification."
        # "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions."
        # "If the query contains multiple questions, answer only the first question and ask the user to ask the next question."
        # "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately."
        # "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."

        Respond with a JSON object containing:
        - "valid_question": boolean (true if question is valid, false if not)
        - "response": Explain about the validity of the question and suggest improvements. But dont give any example questions.


        Examples of invalid questions:
        - "Tell me everything about X" (too broad)
        - "What about A, B, C, and also how does D relate to E?" (too many topics)
        - "Why is it better?" (vague, missing context)
        - Questions that would require assumptions about unstated context

        Examples of valid questions:
        - "What is the definition of X?"
        - "How does process A work?"
        - "What are the main benefits of technique B?"
        """

            # Verdict used whenever the model call fails or yields nothing usable.
            # The question is treated as NOT valid in that case.
            failure_update = {
                "query_validation": ValidateQuery(
                    valid_question=False,
                    response="Error in Question validation",
                ),
                "user_query": latest_question,
            }

            try:
                structured_llm_grader = llm.with_structured_output(ValidateQuery)
                prompt = [
                    SystemMessage(validation_prompt),
                    f"Question to analyze: {latest_question}",
                ]
                validation_response = structured_llm_grader.invoke(prompt)
            except Exception:
                # Best-effort node: record the traceback and fall back instead
                # of aborting the whole graph run.
                logger.exception("Error in question validation")
                return failure_update

            end = time.time()
            logger.info(f"Time taken for question validation: {end - start} seconds")
            logger.info(f"Question validation result: {validation_response}")

            if not isinstance(validation_response, ValidateQuery):
                # Downstream nodes need attribute access; anything else is
                # handled as a failed validation.
                return failure_update

            return {"query_validation": validation_response,
                    "user_query": latest_question}


        def respond_or_call_retrieve_tool(state: State):
            """Let the LLM either answer directly or request a retrieval tool call.

            Args:
                state: Current graph state; only ``state["messages"]`` is read.

            Returns:
                dict: ``{"messages": [response]}`` with the model's reply (which
                may carry tool calls), or ``{"messages": []}`` when there is no
                usable message to send.
            """
            start = time.time()

            # Only send messages the model can consume. Messages with tool
            # calls are kept even if their content is empty, so any ToolMessage
            # that follows still has its matching request.
            usable_messages = [
                msg for msg in state["messages"]
                if msg.content or getattr(msg, "tool_calls", None)
            ]
            if not usable_messages:
                return {"messages": []}

            llm_with_tools = llm.bind_tools([retrieve_tool])
            # The filtered list is what gets sent; previously it was computed
            # but the unfiltered history was passed to the model.
            response = llm_with_tools.invoke(usable_messages)
            end = time.time()
            logger.info(f"Time taken for query_or_respond_fn LLM invocation: {end - start} seconds")

            # Returned as a one-element list so it is added to the existing
            # history rather than replacing it.
            return {"messages": [response]}


        @tool(response_format="content_and_artifact")
        def retrieve_tool(query: str):
            """Retrieve information related to a query."""
            # The docstring above doubles as the tool description shown to the
            # model by LangChain's @tool decorator, so its wording is left as is.
            #
            # Returns a (content, artifact) pair: the serialized text of the
            # retrieved documents and the document objects themselves.
            started_at = time.time()
            documents = compression_retriever.invoke(input=query, k=30)
            rendered = [
                f"Source: {doc.metadata}\nContent: {doc.page_content}"
                for doc in documents
            ]
            serialized = "\n\n".join(rendered)
            finished_at = time.time()
            logger.info(f"Time taken for vectordb retrieval: {finished_at - started_at} seconds")
            logger.info(f"Retrieved {len(documents)} documents for query: {query}")
            return serialized, documents


        def extract_retrieved_docs(state: State):
            """Collect the artifacts of every ``retrieve_tool`` result in the history.

            Scans all of ``state["messages"]`` for ``ToolMessage`` objects named
            ``"retrieve_tool"`` and concatenates their ``artifact`` payloads.
            A message whose artifact cannot be added is logged and skipped.

            Returns:
                dict: ``{"retrieved_docs": [...]}`` (an empty list when nothing
                was found).
            """
            collected = []
            logger.info("Entered extract_retrieved_docs")

            tool_results = (
                msg
                for msg in state["messages"]
                if isinstance(msg, ToolMessage) and msg.name == "retrieve_tool"
            )
            for msg in tool_results:
                try:
                    collected.extend(msg.artifact)
                except Exception as exc:
                    # Skip this result and keep whatever was gathered so far.
                    logger.info(f"Error parsing tool result: {msg.content}")
                    logger.info(f"Exception: {exc}")

            return {"retrieved_docs": collected}


        def wait_for_query_validation_and_retrieved_docs(state: State):
            """Set the routing flags once both parallel branches have reported.

            Despite its name, the function does not block: it only checks
            whether the ``query_validation`` and ``retrieved_docs`` keys are
            present in the state at the time it is called.

            Returns:
                dict: ``{"generate_response": False}`` while either key is
                missing. Otherwise ``generate_response`` is True and
                ``generate_sample_questions`` is True exactly when the question
                was judged invalid.
            """
            logger.info("Entered wait_for_query_validation_and_retrieved_docs")
            logger.info(f" state.keys() : {state.keys()}")

            has_validation = "query_validation" in state
            has_docs = "retrieved_docs" in state
            logger.info(f"query_validation status :  {has_validation}")
            logger.info(f"retrieved_docs status : {has_docs}")
            logger.info(f"state['query_validation'] : {state.get('query_validation')}")

            if has_validation and has_docs:
                logger.info("Both question validation and retrieved documents are present in state.")
                validation = state["query_validation"]
                logger.info(f"state['query_validation'] : {validation}")

                # The verdict may be a ValidateQuery model or a plain dict
                # carrying the same keys; read it without assuming either.
                if isinstance(validation, dict):
                    is_valid = bool(validation.get("valid_question", False))
                else:
                    is_valid = bool(getattr(validation, "valid_question", False))

                # An invalid question is answered with sample questions instead.
                response = {
                    "generate_sample_questions": not is_valid,
                    "generate_response": True,
                }
            else:
                # One branch has not reported yet: keep the generate flag off.
                response = {"generate_response": False}

            logger.info(f"wait_for_query_validation_and_retrieved_docs response : {response}")
            return response

        def generate_questions(state: State):
            """Suggest better-phrased questions based on the retrieved documents.

            Builds a prompt from the user's query, the serialized retrieved
            documents and the validation verdict, and asks the LLM for 3-5
            sample questions.

            Args:
                state: Graph state; reads ``user_query``, ``retrieved_docs`` and
                    ``query_validation``.

            Returns:
                dict: ``messages`` with the model reply appended and
                ``sample_questions`` holding the reply's content.
            """
            start = time.time()

            user_query = state["user_query"]

            # The verdict may be a ValidateQuery model or a plain dict carrying
            # the same keys; read it without assuming either.
            validation = state["query_validation"]
            if isinstance(validation, dict):
                is_valid = validation.get("valid_question")
                justification = validation.get("response")
            else:
                is_valid = validation.valid_question
                justification = validation.response

            # Combine the retrieved documents into a single response
            combined_response = "\n\n".join(
                f"Source: {doc.metadata}\nContent: {doc.page_content}"
                for doc in state["retrieved_docs"]
            )

            question_generation_prompt = f"""
        # Question Nudging Prompt

        You are a helpful assistant that generates better question suggestions for users based on available document content.

        ## Context
        A user has asked a question that may be unclear, too broad, or not well-phrased. You have access to relevant documents retrieved from a vector database. Your task is to generate 3-5 sample questions that:

        1. Are directly answerable using the provided document content
        2. Guide the user toward more specific and actionable queries
        3. Help clarify what the user might actually be looking for
        4. Are phrased clearly and concisely

        ## Input Format
        **User Question:** {user_query}

        **Retrieved Documents:** 
        {combined_response}

        **Question Validity Status:** {is_valid}

        **Question validity justification:** {justification}


        ## Instructions
        1. Analyze the user's original question and the retrieved document content
        2. Identify key topics, concepts, and specific information available in the documents
        3. Generate 3-5 sample questions that:
        - Are more specific than the original question
        - Can be fully answered using only the provided document content
        - Cover different aspects or angles related to the user's intent
        - Use clear, direct language
        - Avoid assumptions not supported by the documents
        4. During response generation, ad justification of question validity which is provided to you in the state["query_validation"].response

        ## Output Format
        {justification}

        Based on the available information, here are some more specific questions you might want to ask:

        1. [Specific question 1 based on document content]
        2. [Specific question 2 based on document content]  
        3. [Specific question 3 based on document content]
        4. [Specific question 4 based on document content] (if applicable)
        5. [Specific question 5 based on document content] (if applicable)

        ## Important Guidelines
        - Only suggest questions that can be answered using the provided documents
        - Do not make up information not present in the documents
        - Focus on practical, actionable questions
        - Vary the scope and angle of suggested questions
        - Keep questions concise and clear
        - If the original question is valid but could be more specific, acknowledge this in your suggestions
            """

            # Time the model call separately so the two log lines below report
            # different things (whole node vs. LLM call only).
            start_llm = time.time()
            response = llm.invoke(question_generation_prompt)

            end = time.time()
            logger.info(f"Time taken for generate_fn : {end - start} seconds")
            logger.info(f"Time taken for generate_fn LLM invocation: {end - start_llm} seconds")

            return {"messages": [response], "sample_questions": response.content}



        def generate_answer_to_query(state: State):
            """Answer the user's question using only the retrieved documents.

            Args:
                state: Graph state; reads ``retrieved_docs`` and ``messages``.

            Returns:
                dict: ``messages`` with the assistant reply appended and
                ``final_response`` holding the reply text. When no documents
                were retrieved, the reply is the fixed text
                ``"No relevant documents found."``.
            """
            logger.info("Entered generate_answer_to_query")
            start = time.time()

            logger.info(f"state.keys() : {state.keys()}")

            retrieved_docs = state.get("retrieved_docs")
            if not retrieved_docs:
                logger.info("No relevant documents found.")
                # Reply as an AI message: a bare string placed in ``messages``
                # would be recorded as a human message. Also fill in
                # ``final_response`` like the normal path does.
                fallback = AIMessage("No relevant documents found.")
                return {"messages": [fallback], "final_response": fallback.content}

            # Source summary, used for logging only.
            sources_text = "".join(
                f"Source: {doc.metadata.get('source')}, "
                f"Page: {doc.metadata.get('page')}, "
                f"Page Label: {doc.metadata.get('page_label')}\n"
                for doc in retrieved_docs
            )
            logger.info(f"sources_text {sources_text}")

            docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)
            # NOTE(review): the adjacent literals below are concatenated
            # without separating spaces, so the sentences run together in the
            # prompt. Left unchanged to keep the prompt text stable.
            system_message_content = (
                "You are an assistant for question-answering tasks."
                "Use the following pieces of retrieved context to answer the question."
                "This is your only source of knowledge."
                "If you don't know the answer, say that you don't know and STOP - do not provide related information."
                "You are not allowed to make up answers."
                "You are not allowed to use any external knowledge."
                "You are not allowed to make assumptions."
                "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content."
                "Keep your answers strictly limited to information that directly answers the user's specific question."
                "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics."
                "Keep your answers accurate and concise to the source content."
                "\n\n"
                f"{docs_content}"

            )
            # Keep human/system turns and AI turns without tool calls; tool
            # requests and tool results are left out of the answering prompt.
            conversation_messages = [
                message
                for message in state["messages"]
                if message.type in ("human", "system")
                or (message.type == "ai" and not message.tool_calls)
            ]
            prompt = [SystemMessage(system_message_content)] + conversation_messages

            # Run
            start_llm = time.time()
            response = llm.invoke(prompt)

            end = time.time()
            logger.info(f"Time taken for generate_fn : {end - start} seconds")
            logger.info(f"Time taken for generate_fn LLM invocation: {end - start_llm} seconds")

            return {"messages": [response], "final_response": response.content}
            
            
        def generate_flat_response(state: State):
            """Return a fixed apology when no response can be generated.

            Returns:
                dict: ``messages`` with the apology appended as an AI message
                and ``final_response`` holding the same text as a string, which
                matches what the other response-generating nodes store.
            """
            logger.info("Generateing flat response")

            apology = AIMessage("Unable to generate response. Please try again later.")

            # Store the text, not the message object: ``final_response`` is
            # declared as ``str`` on the state.
            return {"messages": [apology], "final_response": apology.content}



        def final_response_router(state: State):
            """Pick the node that runs after the readiness check.

            Returns:
                ``"generate_questions"`` when both ``generate_response`` and
                ``generate_sample_questions`` are truthy,
                ``"generate_answer_to_query"`` when only ``generate_response``
                is truthy, and ``END`` otherwise. ``"generate_flat_response"``
                is never returned from here.
            """
            should_generate = state.get("generate_response")
            wants_samples = state.get("generate_sample_questions")

            # Values are bound to locals first so the f-strings do not nest the
            # enclosing quote character (a SyntaxError before Python 3.12).
            logger.info("Entered final_response_router")
            logger.info(f"state.keys() : {state.keys()}")
            logger.info(f"state.get('generate_response'): {should_generate}")
            logger.info(f"state.get('generate_sample_questions'): {wants_samples}")

            if not should_generate:
                logger.info("generate_response is False")
                return END

            logger.info("generate_response is True")
            if wants_samples:
                return "generate_questions"
            return "generate_answer_to_query"


        tools_node = ToolNode(tools=[retrieve_tool])

        graph_builder = StateGraph(State)

        graph_builder.add_node(node="validate_query", action=validate_query)
        graph_builder.add_node(node="respond_or_call_retrieve_tool", action=respond_or_call_retrieve_tool)
        graph_builder.add_node(node="tools", action=tools_node)
        graph_builder.add_node(node="extract_retrieved_docs", action=extract_retrieved_docs)
        graph_builder.add_node(node="wait_for_query_validation_and_retrieved_docs", action=wait_for_query_validation_and_retrieved_docs)
        graph_builder.add_node(node="generate_questions", action=generate_questions)
        graph_builder.add_node(node="generate_answer_to_query", action=generate_answer_to_query)
        graph_builder.add_node(node="generate_flat_response", action=generate_flat_response)

        graph_builder.set_entry_point(key="validate_query")
        graph_builder.set_entry_point(key="respond_or_call_retrieve_tool")

        graph_builder.add_conditional_edges(
            source = "respond_or_call_retrieve_tool",
            path = tools_condition,
            path_map= {END:END, "tools": "tools"}
        )

        graph_builder.add_edge(start_key="validate_query", end_key="wait_for_query_validation_and_retrieved_docs")
        graph_builder.add_edge(start_key="respond_or_call_retrieve_tool", end_key="tools")
        graph_builder.add_edge(start_key="tools", end_key="extract_retrieved_docs")
        graph_builder.add_edge(start_key="extract_retrieved_docs", end_key="wait_for_query_validation_and_retrieved_docs")



        graph_builder.add_conditional_edges(
            "wait_for_query_validation_and_retrieved_docs",
            final_response_router,
            {
                "generate_questions": "generate_questions",
                "generate_answer_to_query": "generate_answer_to_query",
                "generate_flat_response": "generate_flat_response",
                END: END
            }
        )

        graph_builder.add_edge(start_key="generate_questions", end_key=END)
        graph_builder.add_edge(start_key="generate_answer_to_query", end_key=END)
        graph_builder.add_edge(start_key="generate_flat_response", end_key=END)


        graph = graph_builder.compile()

        
        st.success("Initialization complete!")
        return {"graph": graph}