""" Author: Nikhil Nageshwar Inturi (GitHub: @unikill066) Date: 2025-06-22 Create a langgraph graph and compile it for invocation """ # imports import streamlit as st, warnings, os, logging, sys from constants import COLLECTION_NAME from dotenv import load_dotenv warnings.filterwarnings("ignore") from typing import Annotated, Literal, Sequence, TypedDict from langchain import hub from langchain_core.messages import HumanMessage from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.pydantic_v1 import Field from pydantic import BaseModel from langgraph.graph.message import add_messages from langgraph.prebuilt import tools_condition from langchain_community.vectorstores import Chroma from langchain.tools.retriever import create_retriever_tool from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode from langchain_openai import ChatOpenAI from langchain_openai import OpenAIEmbeddings # load environment variablesx load_dotenv() # logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # validate openai api key openai_api_key = st.secrets["OPENAI_API_KEY"] if not openai_api_key: st.error("OpenAI API key not found in environment variables.") llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.5, api_key=openai_api_key) embedding_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key) class AgentState(TypedDict): # agent state across the graph execution messages: Annotated[Sequence[BaseMessage], add_messages] # creating a custom retriever tool for agentic tool use # refer to bin/retriever.py vectorstore = Chroma(persist_directory="/Users/discovery/Desktop/agentic-rag/chroma_db", embedding_function=embedding_model, collection_name=COLLECTION_NAME) retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) # k is the number of documents to retrieve # 
# Names this section needs beyond the file's top-level imports. Field is
# re-imported from pydantic v2 so it matches BaseModel (a v1 Field from the
# langchain_core.pydantic_v1 shim is not recognised by a v2 model).
import re  # noqa: E402
from functools import lru_cache  # noqa: E402

from pydantic import BaseModel, Field  # noqa: E402

# Upper bound on retrieval rounds per run. Without it the
# rewrite -> rag_agent -> retriever_node loop only stops when LangGraph's
# recursion limit aborts the run with an error.
MAX_RETRIEVAL_ATTEMPTS = 3

# Short inputs containing one of these phrases (as whole words) skip retrieval.
_SIMPLE_GREETINGS = ("hi", "hello", "hey", "how are you", "good morning", "good afternoon", "good evening")

# Custom retriever tool for agentic tool use; the description is the text the
# LLM sees when deciding whether to call it.
retriever_tool = create_retriever_tool(
    retriever,
    "retriever",
    """You are a specialized assistant and you have to search and return information about Nikhil from the documents
Use the `retriever` tool **only** when the query explicitly related to Nikhil or queries about Nikhil.
For all other queries, respond directly without using this custom `retriever` tool.
And, for simple queries like 'hi', 'hello', or 'how are you', provide a short humanable response.
""",
)

tools = [retriever_tool]  # tools exposed to the agent LLM
retriever_node = ToolNode(tools)  # graph node that executes the agent's tool calls

# Bound once at import instead of on every rag_agent call.
_llm_with_tools = llm.bind_tools(tools)


class router(BaseModel):
    """Structured output of the document-relevance grader."""

    route: str = Field(description="Route to 'yes' or 'no' based on relevance of query")


_GRADE_PROMPT = PromptTemplate(
    template="""
You are a helpful assistant checking document relevance.

Query: {query}

Document: {context}

Is this document relevant to answering the query?
- If the document contains information that can help answer the query, return 'yes'
- If the document is not relevant or doesn't contain useful information, return 'no'
""",
    input_variables=["context", "query"],
)

# Used by generator when the hub prompt cannot be fetched.
_FALLBACK_RAG_PROMPT = PromptTemplate(
    template="""
Based on the following context, answer the question:

Context: {context}

Question: {question}

Answer:""",
    input_variables=["context", "question"],
)

_REWRITE_PROMPT = PromptTemplate(
    template="""
The original query was: {query}

The retrieval didn't find relevant information. Please rewrite this query to be more specific and likely to find relevant information about Nikhil's background, experience, or qualifications.

Rewritten query:""",
    input_variables=["query"],
)

# Relevance grader: prompt piped into the LLM constrained to the router schema.
_relevance_grader = _GRADE_PROMPT | llm.with_structured_output(router)


def _message_text(message) -> str:
    """Return a message's content as text, falling back to str() for other objects."""
    content = message.content if hasattr(message, "content") else message
    return content if isinstance(content, str) else str(content)


def _first_human_query(messages):
    """Return the content of the first HumanMessage in *messages*, or None if there is none.

    NOTE(review): this treats the first human turn as "the" query, i.e. it
    assumes the state holds a single user turn (plus any rewrites, which are
    appended later). Confirm against how the app invokes the graph.
    """
    return next((msg.content for msg in messages if isinstance(msg, HumanMessage)), None)


def _count_retrievals(messages) -> int:
    """Count tool-result messages in *messages*, i.e. completed retrieval rounds."""
    return sum(1 for msg in messages if getattr(msg, "type", None) == "tool")


def _is_simple_greeting(query: str) -> bool:
    """Return True if *query* has at most 3 words and contains a greeting phrase as whole words.

    Whole-word matching matters: a plain substring test would classify
    "Who is Nikhil?" as a greeting because "nikhil" contains "hi".
    """
    if len(query.split()) > 3:
        return False
    padded = " " + " ".join(re.findall(r"[a-z0-9']+", query.lower())) + " "
    return any(f" {greeting} " in padded for greeting in _SIMPLE_GREETINGS)


@lru_cache(maxsize=1)
def _pull_rag_prompt():
    """Fetch the rlm/rag-prompt from the LangChain hub once per process.

    Failures raise and are not cached, so a later call retries the fetch.
    """
    return hub.pull("rlm/rag-prompt")


def rag_agent(state: AgentState) -> AgentState:
    """Answer the latest message, letting the LLM decide whether to call the retriever.

    Short greetings (see _is_simple_greeting) are answered directly by the
    plain LLM. Anything else goes to the tool-bound LLM; tools_condition in
    the graph then routes to retriever_node if the reply contains tool calls.

    Returns:
        A state update whose single message is appended to ``messages``.
    """
    logger.info("\n - - - RAG Agent Invocation - - -\n")
    messages = state["messages"]
    query = _message_text(messages[-1])
    logger.info("Query received: %s", query)

    if _is_simple_greeting(query):
        logger.info("Simple greeting - responding directly")
        response = llm.invoke([HumanMessage(content="Respond to this greeting in a friendly way: " + query)])
    else:
        logger.info("Using LLM with tools - letting tools_condition decide")
        # Instructions are sent ahead of the conversation but never stored in state.
        system_message = HumanMessage(content=f"""
You are a helpful assistant that answers questions about Nikhil.
For ANY query about Nikhil (background, experience, education, projects, skills, work, etc.), you MUST use the 'retriever' tool to search for relevant information first.
For simple greetings like 'hi', 'hello', or 'how are you', respond directly without tools.

Current query: {query}
""")
        response = _llm_with_tools.invoke([system_message, *messages])

    logger.info("RAG Agent Response type: %s", type(response))
    logger.info("RAG Agent Response: %s", response)
    tool_calls = getattr(response, "tool_calls", None)
    if tool_calls:
        logger.info("Tool calls detected: %d tool(s)", len(tool_calls))
        for i, tool_call in enumerate(tool_calls, start=1):
            logger.info("Tool call %d: %s", i, tool_call)
    else:
        logger.info("No tool calls in response")
    return {"messages": [response]}


def document_quality(state: AgentState) -> Literal["rewrite", "generator"]:
    """Grade the last retrieved document against the original query and pick the next node.

    Returns:
        "generator" if the grader says the document is relevant, or if
        MAX_RETRIEVAL_ATTEMPTS retrievals have already happened (to end the
        rewrite loop); otherwise "rewrite".
    """
    logger.info("\n - - - Document Quality Invocation - - -\n")
    messages = state["messages"]

    if len(messages) < 2:
        logger.info("Not enough messages for quality check - going to rewrite")
        return "rewrite"

    original_query = _first_human_query(messages)
    if not original_query:
        logger.info("No original query found - going to rewrite")
        return "rewrite"

    document = _message_text(messages[-1])  # output of retriever_node
    logger.info("Checking quality for query: %s", original_query)
    logger.info("Document snippet: %s...", document[:200])

    response = _relevance_grader.invoke({"context": document, "query": original_query})
    route_to = response.route.strip().lower()
    logger.info("Quality check result: %s", route_to)

    if route_to == "yes":
        logger.info("Document is relevant - going to generator")
        return "generator"

    attempts = _count_retrievals(messages)
    if attempts >= MAX_RETRIEVAL_ATTEMPTS:
        # Stop looping; the generator is left to answer from what it has.
        logger.warning("Document still not relevant after %d retrieval(s) - going to generator", attempts)
        return "generator"

    logger.info("Document is not relevant - going to rewrite")
    return "rewrite"


def generator(state: AgentState) -> AgentState:
    """Answer the original query from the last retrieved document.

    Uses the hub "rlm/rag-prompt" when it can be fetched and a local prompt
    otherwise. Only the prompt fetch is guarded; LLM errors propagate.
    """
    logger.info("\n - - - Generator Invocation - - -\n")
    messages = state["messages"]
    original_query = _first_human_query(messages)
    document = _message_text(messages[-1])
    logger.info("Generating answer for: %s", original_query)

    try:
        prompt = _pull_rag_prompt()
    except Exception as e:  # hub unreachable / prompt missing: use the local prompt
        logger.error("Error with hub prompt: %s", e)
        prompt = _FALLBACK_RAG_PROMPT

    response = (prompt | llm).invoke({"context": document, "question": original_query})
    logger.info("Generator Response: %s", response)
    return {"messages": [response]}


def rewrite(state: AgentState) -> AgentState:
    """Ask the LLM to rephrase the original query and append it as a new HumanMessage.

    Falls back to "Tell me about Nikhil" when no human message is present.
    """
    logger.info("\n - - - Rewrite Invocation - - -\n")
    original_query = _first_human_query(state["messages"]) or "Tell me about Nikhil"
    logger.info("Rewriting query: %s", original_query)

    response = (_REWRITE_PROMPT | llm).invoke({"query": original_query})
    logger.info("Rewritten query: %s", response)
    return {"messages": [HumanMessage(content=_message_text(response))]}


def build_rag_state_graph():
    """Wire the RAG nodes into a StateGraph and return the compiled app.

    Flow: START -> rag_agent; rag_agent -> retriever_node when tools_condition
    returns "tools", else END; retriever_node -> generator or rewrite as chosen
    by document_quality; rewrite -> rag_agent; generator -> END.
    """
    logger.info("\n - - - Building RAG State Graph - - -\n")
    graph = StateGraph(AgentState)

    # nodes
    graph.add_node("rag_agent", rag_agent)
    graph.add_node("retriever_node", retriever_node)
    graph.add_node("generator", generator)
    graph.add_node("rewrite", rewrite)

    # edges
    graph.add_edge(START, "rag_agent")
    graph.add_conditional_edges("rag_agent", tools_condition, {"tools": "retriever_node", END: END})
    graph.add_conditional_edges("retriever_node", document_quality, {"generator": "generator", "rewrite": "rewrite"})
    graph.add_edge("rewrite", "rag_agent")
    graph.add_edge("generator", END)

    logger.info("\n - - - RAG State Graph Built - - -\n")
    return graph.compile()


def save_mermaid_graph(app, output_path: str = "./graph.png") -> None:
    """Generate the app's Mermaid diagram and save it as a PNG file.

    Args:
        app: compiled graph, e.g. the value returned by build_rag_state_graph().
        output_path: destination file; overwritten if it already exists.
    """
    png_bytes = app.get_graph(xray=True).draw_mermaid_png()
    with open(output_path, "wb") as f:
        f.write(png_bytes)


app = build_rag_state_graph()

# The diagram is a convenience artifact. Rendering may depend on an external
# service, and its failure should not stop this module (and `app`) from loading.
try:
    save_mermaid_graph(app)
except Exception as exc:
    logger.warning("Could not save graph PNG: %s", exc)
else:
    logger.info("\n - - - Graph saved to PNG - - -\n")