File size: 10,657 Bytes
70b22b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
Author: Nikhil Nageshwar Inturi (GitHub: @unikill066)
Date: 2025-06-22

Create a langgraph graph and compile it for invocation
"""

# imports — grouped stdlib / third-party / local per convention
import logging
import os
import sys
import warnings
from typing import Annotated, Literal, Sequence, TypedDict

import streamlit as st
from dotenv import load_dotenv
from langchain import hub
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
# NOTE: Field must come from pydantic (v2) to match the pydantic BaseModel
# used by the `router` schema below; the previous import from
# langchain_core.pydantic_v1 mixed a v1 FieldInfo into a v2 model, which
# silently drops/breaks field metadata.
from pydantic import BaseModel, Field

from constants import COLLECTION_NAME

warnings.filterwarnings("ignore")  # suppress noisy third-party deprecation warnings

# load environment variables from a local .env file (if present)
load_dotenv()

# logging — module-level logger per stdlib convention
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Resolve the OpenAI API key: prefer Streamlit secrets, fall back to the
# process environment. Using .get() instead of st.secrets["..."] means a
# missing key no longer raises before the check below, so the explicit
# st.error() message is actually reachable.
openai_api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY"))
if not openai_api_key:
    st.error("OpenAI API key not found in environment variables.")

# Chat model shared by every graph node, plus the embedding model that the
# Chroma collection below was built with (must match at index and query time).
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.5, api_key=openai_api_key)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)

class AgentState(TypedDict):  # agent state across the graph execution
    # Conversation history. The `add_messages` reducer makes langgraph
    # APPEND the messages each node returns instead of replacing the list.
    messages: Annotated[Sequence[BaseMessage], add_messages]

# creating a custom retriever tool for agentic tool use
# refer to bin/retriever.py
# The persist directory is configurable via CHROMA_PERSIST_DIR so the app is
# not tied to one developer's machine; the default preserves prior behavior.
vectorstore = Chroma(
    persist_directory=os.getenv(
        "CHROMA_PERSIST_DIR", "/Users/discovery/Desktop/agentic-rag/chroma_db"
    ),
    embedding_function=embedding_model,
    collection_name=COLLECTION_NAME,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})  # k is the number of documents to retrieve

# Tool wrapper around the Chroma retriever. The description string below is
# read by the LLM at run time to decide whether to call the tool, so it is
# part of the program's behavior — do not edit it casually.
retriever_tool = create_retriever_tool(
    retriever, 
    "retriever",
    """You are a specialized assistant and you have to search and return information about Nikhil from the documents
    Use the `retriever` tool **only** when the query explicitly related to Nikhil or queries about Nikhil.
    For all other queries, respond directly without using this custom `retriever` tool.
    And, for simple queries like 'hi', 'hello', or 'how are you', provide a short humanable response.
    """
)

tools = [retriever_tool, ]  # list of tools - Internet Search CHECK - [x]

# create a tool node
# ToolNode executes any tool calls emitted by the agent and appends the
# resulting ToolMessages to state["messages"].
retriever_node = ToolNode([retriever_tool])
# Structured-output schema for the relevance grade in `document_quality`:
# the LLM is constrained to fill `route` with 'yes' or 'no'.
# NOTE(review): lowercase class name kept for compatibility with callers.
class router(BaseModel):
    route: str=Field(description="Route to 'yes' or 'no' based on relevance of query")

def _is_simple_greeting(query: str) -> bool:
    """Return True when *query* is a short social greeting (e.g. 'hi', 'good morning').

    Matches greetings on word boundaries as contiguous token runs. The previous
    substring check (`'hi' in query.lower()`) misclassified queries such as
    "what is this" — which merely CONTAINS 'hi' — as greetings.
    """
    simple_greetings = ['hi', 'hello', 'hey', 'how are you', 'good morning', 'good afternoon', 'good evening']
    tokens = [t.strip("?!.,") for t in query.lower().split()]
    if not tokens or len(tokens) > 3:  # greetings are short; keep original <= 3-word limit
        return False
    for greeting in simple_greetings:
        g_tokens = greeting.split()
        # phrase match: greeting tokens must appear as a contiguous run
        for i in range(len(tokens) - len(g_tokens) + 1):
            if tokens[i:i + len(g_tokens)] == g_tokens:
                return True
    return False

def rag_agent(state: AgentState) -> AgentState:
    """Graph entry node.

    Answers trivial greetings directly; for everything else it invokes the
    LLM with the retriever tool bound and lets `tools_condition` route the
    result (tool call -> retriever_node, plain answer -> END).

    Args:
        state: current graph state; the last message holds the user query.
    Returns:
        Partial state with the LLM response appended to `messages`.
    """
    logger.info("\n - - - RAG Agent Invocation - - -\n")
    messages = state["messages"]
    latest_message = messages[-1]
    query = latest_message.content if hasattr(latest_message, 'content') else str(latest_message)
    logger.info(f"Query received: {query}")
    # System instruction steering the tool-bound LLM toward the retriever
    # for any Nikhil-related query.
    system_message = HumanMessage(content=f"""
    You are a helpful assistant that answers questions about Nikhil. 
    For ANY query about Nikhil (background, experience, education, projects, skills, work, etc.), 
    you MUST use the 'retriever' tool to search for relevant information first.
    For simple greetings like 'hi', 'hello', or 'how are you', respond directly without tools.
    Current query: {query}
    """)
    if _is_simple_greeting(query):
        logger.info("Simple greeting - responding directly")
        response = llm.invoke([HumanMessage(content="Respond to this greeting in a friendly way: " + query)])
    else:
        logger.info("Using LLM with tools - letting tools_condition decide")
        llm_with_tools = llm.bind_tools(tools)
        enhanced_messages = [system_message] + messages
        response = llm_with_tools.invoke(enhanced_messages)
    logger.info(f"RAG Agent Response type: {type(response)}")
    logger.info(f"RAG Agent Response: {response}")

    if hasattr(response, 'tool_calls') and response.tool_calls:
        logger.info(f"Tool calls detected: {len(response.tool_calls)} tool(s)")
        for i, tool_call in enumerate(response.tool_calls):
            logger.info(f"Tool call {i+1}: {tool_call}")
    else:
        logger.info("No tool calls in response")
    return {"messages": [response]}

def document_quality(state: AgentState) -> Literal["rewrite", "generator"]:
    """Conditional edge after retrieval.

    Grades the most recent message (the retrieved document) against the
    user's original question and routes to 'generator' when relevant,
    otherwise to 'rewrite'.
    """
    logger.info("\n - - - Document Quality Invocation - - -\n")
    messages = state["messages"]

    # Guard: need at least a query and a retrieval result to grade.
    if len(messages) < 2:
        logger.info("Not enough messages for quality check - going to rewrite")
        return "rewrite"

    # The first human turn in the history is taken as the user's query.
    original_query = next(
        (msg.content for msg in messages if isinstance(msg, HumanMessage)), None
    )
    if not original_query:
        logger.info("No original query found - going to rewrite")
        return "rewrite"

    tail = messages[-1]
    document = tail.content if hasattr(tail, 'content') else str(tail)
    logger.info(f"Checking quality for query: {original_query}")
    logger.info(f"Document snippet: {document[:200]}...")

    # Constrain the grader to the `router` schema so we get a clean yes/no.
    grading_chain = (
        PromptTemplate(template="""
    You are a helpful assistant checking document relevance.
    Query: {query}
    Document: {context}
    Is this document relevant to answering the query? 
    - If the document contains information that can help answer the query, return 'yes'
    - If the document is not relevant or doesn't contain useful information, return 'no'
    """, input_variables=["context", "query"])
        | llm.with_structured_output(router)
    )
    verdict = grading_chain.invoke(
        {"context": document, "query": original_query}
    ).route.lower()
    logger.info(f"Quality check result: {verdict}")

    if verdict == "yes":
        logger.info("Document is relevant - going to generator")
        return "generator"
    logger.info("Document is not relevant - going to rewrite")
    return "rewrite"

def generator(state: AgentState) -> AgentState:
    """Produce the final answer from the retrieved document and the user's question.

    Prefers the community RAG prompt from LangChain Hub; falls back to a
    local PromptTemplate if the hub pull fails (e.g. no network access).
    """
    logger.info("\n - - - Generator Invocation - - -\n")
    messages = state["messages"]
    # First human turn = the question; last message = retrieved context.
    original_query = next(
        (msg.content for msg in messages if isinstance(msg, HumanMessage)), None
    )
    tail = messages[-1]
    document = tail.content if hasattr(tail, 'content') else str(tail)
    logger.info(f"Generating answer for: {original_query}")
    try:
        rag_chain = hub.pull("rlm/rag-prompt") | llm
        response = rag_chain.invoke({"context": document, "question": original_query})
    except Exception as e:
        logger.error(f"Error with hub prompt: {e}")
        # Fallback prompt
        fallback_prompt = PromptTemplate(template="""
        Based on the following context, answer the question:
        Context: {context}
        Question: {question}
        Answer:""", input_variables=["context", "question"])
        response = (fallback_prompt | llm).invoke(
            {"context": document, "question": original_query}
        )
    logger.info(f"Generator Response: {response}")
    return {"messages": [response]}
    
def rewrite(state: AgentState) -> AgentState:
    """Rewrite the user's query for better retrieval, then loop back to rag_agent.

    Reached when `document_quality` judged the retrieved document irrelevant;
    the rewritten query is appended as a new HumanMessage.
    """
    logger.info("\n - - - Rewrite Invocation - - -\n")
    # First human turn, with a default topic when none (or an empty one) exists.
    original_query = next(
        (msg.content for msg in state["messages"] if isinstance(msg, HumanMessage)),
        None,
    ) or "Tell me about Nikhil"
    logger.info(f"Rewriting query: {original_query}")
    rewrite_prompt = PromptTemplate(template="""
    The original query was: {query}
    The retrieval didn't find relevant information. Please rewrite this query to be more specific and likely to find relevant information about Nikhil's background, experience, or qualifications.
    Rewritten query:""", input_variables=["query"])
    response = (rewrite_prompt | llm).invoke({"query": original_query})
    logger.info(f"Rewritten query: {response}")
    new_text = response.content if hasattr(response, 'content') else str(response)
    return {"messages": [HumanMessage(content=new_text)]}

# NOTE: graph construction lives in build_rag_state_graph() below; a verbatim
# commented-out copy of the same wiring was removed from here as dead code.

def build_rag_state_graph():
    """Assemble and compile the agentic-RAG graph.

    Topology:
        START -> rag_agent
        rag_agent --(tool call)--> retriever_node, otherwise -> END
        retriever_node --(relevant doc)--> generator, otherwise -> rewrite
        rewrite -> rag_agent   (retry with an improved query)
        generator -> END

    Returns:
        The compiled, invokable langgraph application.
    """
    logger.info("\n - - - Building RAG State Graph - - -\n")
    workflow = StateGraph(AgentState)  # stategraph definition
    # nodes
    for node_name, node in (
        ("rag_agent", rag_agent),
        ("retriever_node", retriever_node),
        ("generator", generator),
        ("rewrite", rewrite),
    ):
        workflow.add_node(node_name, node)
    # edges
    workflow.add_edge(START, "rag_agent")
    workflow.add_conditional_edges("rag_agent", tools_condition, {"tools": "retriever_node", END: END})
    workflow.add_conditional_edges("retriever_node", document_quality, {"generator": "generator", "rewrite": "rewrite"})
    workflow.add_edge("rewrite", "rag_agent")
    workflow.add_edge("generator", END)
    logger.info("\n - - - RAG State Graph Built - - -\n")
    return workflow.compile()

# save compiled graph state to a PNG
def save_mermaid_graph(app, output_path: str = "./graph.png") -> None:
    """Render *app*'s Mermaid diagram and write it to *output_path* as a PNG."""
    diagram_bytes = app.get_graph(xray=True).draw_mermaid_png()
    with open(output_path, "wb") as png_file:
        png_file.write(diagram_bytes)

# Build the compiled graph at import time so the Streamlit front-end can
# invoke `app` directly, then persist a PNG visualization of the topology
# (side effect: writes ./graph.png on every import).
app = build_rag_state_graph()
save_mermaid_graph(app)
logger.info("\n - - - Graph saved to PNG - - -\n")