# IIITDMJ_Chatbot / agent.py
# PrinceDeepakSiddharth12's picture
# Upload 5 files
# bc0299d verified
import os
from typing import TypedDict, Annotated, List, Literal
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.documents import Document
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import add_messages
from dotenv import load_dotenv
# Load variables from a .env file before any client is constructed
# (presumably this is where the Google API key comes from -- confirm).
load_dotenv()
# Streaming-enabled model used by the answer-producing nodes
# (handle_chat_node, generate_answer_node, generate_synthesized_answer_node).
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, streaming=True)
# Non-streaming model used for the helper steps: query rewriting,
# classification and decomposition.
classification_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
# NOTE(review): allow_dangerous_deserialization=True opts in to loading a
# serialized index from disk; only point this at an index file you trust.
db = FAISS.load_local("vectorstore/faiss_index2", embeddings, allow_dangerous_deserialization=True)
retriever = db.as_retriever(search_kwargs={'k': 3})  # k=3 documents requested per retrieval call
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes defined in this file."""

    # Conversation messages; updates returned by nodes go through the
    # add_messages reducer rather than replacing the list.
    messages: Annotated[list, add_messages]
    # Documents written by retrieve_docs_node / retrieve_multi_docs_node.
    context: List[Document]
    # Standalone form of the latest user question (set by rewrite_query_node).
    rewritten_query: str
    # Routing label chosen by classify_query_node and read by router.
    query_type: Literal["simple_rag", "comparative_rag", "conversational"]
    # Sub-questions produced by decompose_query_node for comparative queries.
    sub_queries: List[str]
def format_history_for_prompt(messages: list[BaseMessage]) -> str:
    """Render a message list as a plain-text transcript.

    Human messages become ``Human: <content>`` lines and AI messages become
    ``AI: <content>`` lines; every other message type is skipped. Lines are
    joined with a single newline, so an empty input yields ``""``.
    """
    transcript = []
    for message in messages:
        if isinstance(message, HumanMessage):
            speaker = "Human"
        elif isinstance(message, AIMessage):
            speaker = "AI"
        else:
            # System (and any other) messages are not part of the transcript.
            continue
        transcript.append(f"{speaker}: {message.content}")
    return "\n".join(transcript)
def format_docs_for_prompt(docs: List[Document]) -> str:
    """Concatenate the ``page_content`` of each document, separated by a blank line."""
    separator = "\n\n"
    return separator.join(document.page_content for document in docs)
def inject_system_prompt(state: AgentState) -> dict:
    """Add the assistant's system prompt to the conversation once.

    Returns a ``messages`` update containing a single SystemMessage when the
    state holds no SystemMessage yet, and an empty update otherwise.
    """
    print("---NODE: INJECT_SYSTEM_PROMPT (START)---")
    for existing in state["messages"]:
        if isinstance(existing, SystemMessage):
            # A system prompt is already present; nothing to add.
            return {}
    system_prompt = (
        "You are a helpful and professional assistant for IIITDMJ. "
        "You must answer user questions based *only* on the retrieved context. "
        "If the context does not contain the answer, you must state that "
        "you do not have that information. Do not make up answers."
    )
    return {"messages": [SystemMessage(content=system_prompt)]}
def rewrite_query_node(state: AgentState) -> dict:
    """Turn the latest user question into a standalone question.

    The most recent HumanMessage in ``state["messages"]`` is the question.
    Everything *before* it is treated as chat history. When there is no
    usable history (or no human message at all) the question is returned
    unchanged; otherwise the classification LLM rewrites it using the history.

    Returns:
        A state update ``{"rewritten_query": <str>}``.
    """
    print("---NODE: REWRITE_QUERY---")
    messages = state["messages"]

    # Locate the newest human message by position. Slicing with [:-1] is not
    # reliable: inject_system_prompt can append a SystemMessage *after* the
    # question, in which case [:-1] would keep the question itself in the
    # "history" and trigger a pointless rewrite call on the first turn.
    last_human_index = None
    for index in range(len(messages) - 1, -1, -1):
        if isinstance(messages[index], HumanMessage):
            last_human_index = index
            break

    if last_human_index is None:
        # Nothing to rewrite without a user question.
        last_query = ""
        chat_history = ""
    else:
        last_query = messages[last_human_index].content
        chat_history = format_history_for_prompt(messages[:last_human_index])

    if not chat_history:
        print(f"--- Standalone Query: {last_query} ---")
        return {"rewritten_query": last_query}

    prompt = ChatPromptTemplate.from_template(
        """Given the following chat history and the user's latest question,
rewrite the user's question to be a standalone question...
Chat History: {chat_history}
Latest Question: {query}
Standalone Question:"""
    )
    rewrite_chain = prompt | classification_llm | StrOutputParser()
    rewritten_query = rewrite_chain.invoke({"chat_history": chat_history, "query": last_query})
    print(f"--- Rewritten Query: {rewritten_query} ---")
    return {"rewritten_query": rewritten_query}
def classify_query_node(state: AgentState) -> dict:
    """Label the rewritten query as simple_rag, comparative_rag or conversational.

    The classification LLM's free-text reply is searched (case-insensitively)
    for the label names; ``comparative_rag`` is checked first, then
    ``conversational``, and ``simple_rag`` is the fallback.

    Returns:
        A state update ``{"query_type": <label>}``.
    """
    print("---NODE: CLASSIFY_QUERY---")
    template = """Classify the user's query into one of three categories:
1. **simple_rag**: ...
2. **comparative_rag**: ...
3. **conversational**: ...
Query: {query}
"""
    chain = ChatPromptTemplate.from_template(template) | classification_llm | StrOutputParser()
    reply = chain.invoke({"query": state["rewritten_query"]}).lower()

    # First matching label wins; anything unrecognised is treated as simple_rag.
    decision = next(
        (label for label in ("comparative_rag", "conversational") if label in reply),
        "simple_rag",
    )
    print(f"--- Decision: {decision} ---")
    return {"query_type": decision}
def handle_chat_node(state: AgentState) -> dict:
    """Conversational path: reply using the chat transcript alone.

    The full message list (latest question included) is rendered as a
    transcript and handed to the streaming LLM; no retrieval takes place.

    Returns:
        A ``messages`` update holding the generated AIMessage.
    """
    print("---NODE: HANDLE_CHAT---")
    system_text = "You are a helpful college assistant. Answer the user's question based on the chat history. Be conversational."
    user_text = "Here is the chat history (including my last question):\n{chat_history}\n\nNow, please provide a conversational answer."
    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", system_text),
        ("user", user_text),
    ])
    transcript = format_history_for_prompt(state["messages"])
    reply = (chat_prompt | llm | StrOutputParser()).invoke({"chat_history": transcript})
    print(f"--- HANDLE_CHAT generated answer: {reply} ---")
    return {"messages": [AIMessage(content=reply)]}
def retrieve_docs_node(state: AgentState) -> dict:
    """Simple-RAG path: fetch documents for the rewritten query.

    Logs the source and page metadata of each hit (or a warning when there
    are none) and returns the hits as the ``context`` state update.
    """
    print("---NODE: RETRIEVE_DOCS (SIMPLE)---")
    hits = retriever.invoke(state["rewritten_query"])
    print("\n--- RETRIEVED CONTEXT ---")
    if not hits:
        print("!!! No context retrieved. !!!")
    else:
        for position, doc in enumerate(hits, start=1):
            print(f"DOC {position}: Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}")
    print("---------------------------\n")
    return {"context": hits}
def generate_answer_node(state: AgentState) -> dict:
    """Simple-RAG path: answer the rewritten query from the retrieved context.

    The context documents are concatenated into the prompt and the streaming
    LLM produces the answer. A numbered "Sources" list (file name and page)
    is appended unless there are no documents or the answer contains the
    word "website" -- the marker used here for the "I don't have that
    information, visit the website" fallback reply.

    Returns:
        A ``messages`` update holding the final AIMessage.
    """
    print("---NODE: GENERATE_ANSWER (SIMPLE)---")
    query = state["rewritten_query"]
    context_docs = state["context"]
    context_str = format_docs_for_prompt(context_docs)
    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a helpful assistant. Answer the user's question based *only* on the retrieved context. "
            "If the context is empty or irrelevant, you *must* state that you do not have the information "
            "and recommend visiting the official Indian Institute of Information Technology, Design and Manufacturing, Jabalpur (IIITDM Jabalpur) website (https://www.iiitdmj.ac.in/) for more details."
        )),
        ("user", "Context:\n{context}\n\nQuestion:\n{query}")
    ])
    generation_chain = prompt | llm | StrOutputParser()
    answer = generation_chain.invoke({"context": context_str, "query": query})

    sources = []
    if context_docs:
        for i, doc in enumerate(context_docs, start=1):
            source_file = doc.metadata.get('source')
            # Only reduce real paths to their file name. Splitting the 'N/A'
            # placeholder on '/' used to turn a missing source into "A", and
            # str() guards against non-string metadata values.
            source_name = str(source_file).split('/')[-1] if source_file else 'N/A'
            page_num = doc.metadata.get('page', 'N/A')
            sources.append(f" {i}. {source_name} (Page: {page_num})")

    if sources and "website" not in answer:
        pretty_answer = answer + "\n--- \n**Sources:**\n" + "\n".join(sources)
    else:
        pretty_answer = answer
    return {"messages": [AIMessage(content=pretty_answer)]}
def decompose_query_node(state: AgentState) -> dict:
    """Comparative path: split the rewritten query into sub-queries.

    The classification LLM is asked for a JSON object and the value under
    its ``"queries"`` key is used as the list of sub-queries. Because the
    reply is model output, it is validated: if it cannot be parsed as JSON,
    is not an object, lacks ``"queries"``, or contains no non-empty strings,
    the node falls back to a single sub-query equal to the rewritten query
    instead of crashing the graph.

    Returns:
        A state update ``{"sub_queries": [<str>, ...]}`` (never empty).
    """
    # Local import keeps this block self-contained; langchain_core is already
    # a dependency of this file.
    from langchain_core.exceptions import OutputParserException

    print("---NODE: DECOMPOSE_QUERY---")
    query = state["rewritten_query"]
    prompt = ChatPromptTemplate.from_template(
        """You are a query decomposition assistant...
Query: {query}
Respond with a JSON object..."""
    )
    parser = JsonOutputParser()
    decomposition_chain = prompt | classification_llm | parser
    try:
        result = decomposition_chain.invoke({"query": query})
    except OutputParserException as exc:
        print(f"--- Decomposition reply was not valid JSON ({exc}); falling back to the original query ---")
        result = None

    raw_queries = result.get("queries") if isinstance(result, dict) else None
    if isinstance(raw_queries, list):
        # Keep only usable entries; the retriever needs non-empty strings.
        sub_queries = [item for item in raw_queries if isinstance(item, str) and item.strip()]
    else:
        sub_queries = []
    if not sub_queries:
        sub_queries = [query]

    print(f"--- Sub-queries: {sub_queries} ---")
    return {"sub_queries": sub_queries}
def retrieve_multi_docs_node(state: AgentState) -> dict:
    """Comparative path: retrieve documents for every sub-query and de-duplicate.

    Documents are keyed by ``page_content``; when the same content is
    returned more than once, its position is that of the first occurrence
    and the stored object is the last one seen. The de-duplicated list is
    logged and returned as the ``context`` state update.
    """
    print("---NODE: RETRIEVE_DOCS (MULTI)---")
    by_content = {}
    for sub_query in state["sub_queries"]:
        for hit in retriever.invoke(sub_query):
            by_content[hit.page_content] = hit
    unique_docs = list(by_content.values())

    print("\n--- RETRIEVED CONTEXT (MULTI) ---")
    if not unique_docs:
        print("!!! No context retrieved. !!!")
    else:
        for position, doc in enumerate(unique_docs, start=1):
            print(f"DOC {position}: Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}")
    print("---------------------------\n")
    return {"context": unique_docs}
def generate_synthesized_answer_node(state: AgentState) -> dict:
    """Comparative path: synthesize one answer from the merged context.

    The de-duplicated context documents are concatenated into the prompt
    together with the rewritten (original, un-decomposed) question, and the
    streaming LLM produces the answer. A numbered "Sources" list (file name
    and page) is appended unless there are no documents or the answer
    contains the word "website" -- the marker used here for the "I don't
    have that information, visit the website" fallback reply.

    Returns:
        A ``messages`` update holding the final AIMessage.
    """
    print("---NODE: GENERATE_ANSWER (SYNTHESIZED)---")
    query = state["rewritten_query"]
    context_docs = state["context"]
    context_str = format_docs_for_prompt(context_docs)
    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a helpful assistant. Your task is to answer a comparative question based on the provided context. "
            "Synthesize the information from the context to form a comprehensive answer. "
            "If the context is insufficient, you *must* state that you do not have the information "
            "and recommend visiting the official Indian Institute of Information Technology, Design and Manufacturing, Jabalpur (IIITDM Jabalpur) website (https://www.iiitdmj.ac.in/) for more details."
        )),
        ("user", (
            "Here is the context I've gathered:\n{context}\n\n"
            "Now, please answer this original question:\n{query}"
        ))
    ])
    generation_chain = prompt | llm | StrOutputParser()
    answer = generation_chain.invoke({"context": context_str, "query": query})

    sources = []
    if context_docs:
        for i, doc in enumerate(context_docs, start=1):
            source_file = doc.metadata.get('source')
            # Only reduce real paths to their file name. Splitting the 'N/A'
            # placeholder on '/' used to turn a missing source into "A", and
            # str() guards against non-string metadata values.
            source_name = str(source_file).split('/')[-1] if source_file else 'N/A'
            page_num = doc.metadata.get('page', 'N/A')
            sources.append(f" {i}. {source_name} (Page: {page_num})")

    if sources and "website" not in answer:
        pretty_answer = answer + "\n--- \n**Sources:**\n" + "\n".join(sources)
    else:
        pretty_answer = answer
    return {"messages": [AIMessage(content=pretty_answer)]}
def router(state: AgentState) -> Literal["conversational", "simple_rag", "comparative_rag"]:
    """Return the stored ``query_type`` so the graph can pick the next node."""
    destination = state["query_type"]
    print(f"--- ROUTING TO: {destination} ---")
    return destination
# Checkpointer handed to workflow.compile() in build_graph().
# NOTE(review): presumably keeps per-thread state in process memory only,
# so history would not survive a restart -- confirm.
checkpointer = MemorySaver()
def build_graph():
    """Assemble and compile the agent's state graph.

    Flow: inject_system_prompt -> rewrite_query -> classify_query, then a
    three-way branch chosen by ``router``:

    * conversational  -> handle_chat -> END
    * simple_rag      -> retrieve_docs -> generate_answer -> END
    * comparative_rag -> decompose_query -> retrieve_multi_docs
                         -> generate_synthesized_answer -> END

    Returns:
        The graph compiled with the module-level ``checkpointer``.
    """
    workflow = StateGraph(AgentState)

    # Register every node under the name used by the edges below.
    node_table = (
        ("inject_system_prompt", inject_system_prompt),
        ("rewrite_query", rewrite_query_node),
        ("classify_query", classify_query_node),
        ("handle_chat", handle_chat_node),
        ("retrieve_docs", retrieve_docs_node),
        ("generate_answer", generate_answer_node),
        ("decompose_query", decompose_query_node),
        ("retrieve_multi_docs", retrieve_multi_docs_node),
        ("generate_synthesized_answer", generate_synthesized_answer_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    # Shared prefix that every request goes through.
    workflow.set_entry_point("inject_system_prompt")
    prefix_edges = (
        ("inject_system_prompt", "rewrite_query"),
        ("rewrite_query", "classify_query"),
    )
    for origin, target in prefix_edges:
        workflow.add_edge(origin, target)

    # Branch on the label returned by router.
    branch_targets = {
        "conversational": "handle_chat",
        "simple_rag": "retrieve_docs",
        "comparative_rag": "decompose_query",
    }
    workflow.add_conditional_edges("classify_query", router, branch_targets)

    # Remainder of each branch, ending at END.
    branch_edges = (
        ("handle_chat", END),
        ("retrieve_docs", "generate_answer"),
        ("generate_answer", END),
        ("decompose_query", "retrieve_multi_docs"),
        ("retrieve_multi_docs", "generate_synthesized_answer"),
        ("generate_synthesized_answer", END),
    )
    for origin, target in branch_edges:
        workflow.add_edge(origin, target)

    return workflow.compile(checkpointer=checkpointer)
# Build the compiled graph once when the module is imported.
chatbot = build_graph()
if __name__ == "__main__":
    # Manual smoke test: run one question through the graph on a fixed thread id.
    config = {"configurable": {"thread_id": "test-direct-run-1"}}
    print("\n--- Testing Direct Run ---")
    inputs = {"messages": [HumanMessage(content="What is the name of director?")]}
    # stream_mode="values" is requested; each event is treated as a state
    # mapping, and the newest message is printed whenever "messages" is present.
    for event in chatbot.stream(inputs, config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()