Upload 5 files
Browse files- agent.py +281 -0
- app.py +159 -0
- check.ipynb +305 -0
- ingest.py +43 -0
- requirements.txt +13 -0
agent.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import TypedDict, Annotated, List, Literal
|
| 3 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 4 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 5 |
+
from langchain_community.vectorstores import FAISS
|
| 6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 7 |
+
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
|
| 8 |
+
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
|
| 9 |
+
from langchain_core.documents import Document
|
| 10 |
+
from langgraph.graph import StateGraph, END
|
| 11 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 12 |
+
from langgraph.graph import add_messages
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
# Streaming model for user-facing answers; a non-streaming twin handles the
# internal rewrite/classify/decompose steps (same model, temperature 0).
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, streaming=True)
classification_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)

# Local FAISS index built offline (see ingest.py). The deserialization flag is
# required by LangChain when loading a pickled index from disk.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
db = FAISS.load_local("vectorstore/faiss_index2", embeddings, allow_dangerous_deserialization=True)
retriever = db.as_retriever(search_kwargs={'k': 3})  # top-3 chunks per query
|
| 23 |
+
|
| 24 |
+
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph workflow."""
    # Full conversation; the `add_messages` reducer appends instead of replacing.
    messages: Annotated[list, add_messages]
    # Documents retrieved for the current turn.
    context: List[Document]
    # History-aware, standalone form of the latest user question.
    rewritten_query: str
    # Routing decision produced by classify_query_node.
    query_type: Literal["simple_rag", "comparative_rag", "conversational"]
    # Decomposed sub-questions (comparative path only).
    sub_queries: List[str]
|
| 30 |
+
|
| 31 |
+
def format_history_for_prompt(messages: list[BaseMessage]) -> str:
    """Render human/AI turns as a plain-text transcript (other types skipped)."""
    lines = []
    for message in messages:
        if isinstance(message, HumanMessage):
            lines.append(f"Human: {message.content}")
        elif isinstance(message, AIMessage):
            lines.append(f"AI: {message.content}")
    return "\n".join(lines)
|
| 37 |
+
|
| 38 |
+
def format_docs_for_prompt(docs: List[Document]) -> str:
    """Join document bodies with blank-line separators for prompt context."""
    return "\n\n".join(doc.page_content for doc in docs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def inject_system_prompt(state: AgentState) -> dict:
    """Entry node: install the assistant's system prompt exactly once per thread."""
    print("---NODE: INJECT_SYSTEM_PROMPT (START)---")
    # Thread resumed from checkpoint already carries the prompt — do nothing.
    if any(isinstance(m, SystemMessage) for m in state["messages"]):
        return {}
    system_prompt = (
        "You are a helpful and professional assistant for IIITDMJ. "
        "You must answer user questions based *only* on the retrieved context. "
        "If the context does not contain the answer, you must state that "
        "you do not have that information. Do not make up answers."
    )
    return {"messages": [SystemMessage(content=system_prompt)]}
|
| 54 |
+
|
| 55 |
+
def rewrite_query_node(state: AgentState) -> dict:
    """Turn the latest user question into a standalone query using chat history."""
    print("---NODE: REWRITE_QUERY---")
    last_human = next(
        (m for m in reversed(state["messages"]) if isinstance(m, HumanMessage)),
        None,
    )
    last_query = last_human.content if last_human else ""
    chat_history = format_history_for_prompt(state["messages"][:-1])

    # First turn: nothing to resolve, pass the question through unchanged.
    if not chat_history:
        print(f"--- Standalone Query: {last_query} ---")
        return {"rewritten_query": last_query}

    prompt = ChatPromptTemplate.from_template(
        """Given the following chat history and the user's latest question,
rewrite the user's question to be a standalone question...
Chat History: {chat_history}
Latest Question: {query}
Standalone Question:"""
    )
    rewrite_chain = prompt | classification_llm | StrOutputParser()
    rewritten_query = rewrite_chain.invoke({"chat_history": chat_history, "query": last_query})
    print(f"--- Rewritten Query: {rewritten_query} ---")
    return {"rewritten_query": rewritten_query}
|
| 80 |
+
|
| 81 |
+
def classify_query_node(state: AgentState) -> dict:
    """Route the rewritten query: simple RAG, comparative RAG, or small talk.

    The LLM's free-text label is matched by substring; anything unrecognized
    defaults to "simple_rag".
    """
    print("---NODE: CLASSIFY_QUERY---")
    query = state["rewritten_query"]
    prompt = ChatPromptTemplate.from_template(
        """Classify the user's query into one of three categories:
1. **simple_rag**: ...
2. **comparative_rag**: ...
3. **conversational**: ...
Query: {query}
"""
    )
    classification_chain = prompt | classification_llm | StrOutputParser()
    raw_label = classification_chain.invoke({"query": query}).lower()

    if "comparative_rag" in raw_label:
        decision = "comparative_rag"
    elif "conversational" in raw_label:
        decision = "conversational"
    else:
        decision = "simple_rag"
    print(f"--- Decision: {decision} ---")
    return {"query_type": decision}
|
| 100 |
+
|
| 101 |
+
def handle_chat_node(state: AgentState) -> dict:
    """Path A (conversational): answer from the chat history alone — no retrieval."""
    print("---NODE: HANDLE_CHAT---")
    chat_history = format_history_for_prompt(state["messages"])
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful college assistant. Answer the user's question based on the chat history. Be conversational."),
        ("user", "Here is the chat history (including my last question):\n{chat_history}\n\nNow, please provide a conversational answer.")
    ])
    generation_chain = prompt | llm | StrOutputParser()
    answer = generation_chain.invoke({"chat_history": chat_history})

    print(f"--- HANDLE_CHAT generated answer: {answer} ---")
    return {"messages": [AIMessage(content=answer)]}
|
| 118 |
+
|
| 119 |
+
def retrieve_docs_node(state: AgentState) -> dict:
    """Simple-RAG path: fetch the top-k chunks for the rewritten query."""
    print("---NODE: RETRIEVE_DOCS (SIMPLE)---")
    documents = retriever.invoke(state["rewritten_query"])
    print("\n--- RETRIEVED CONTEXT ---")
    if not documents:
        print("!!! No context retrieved. !!!")
    else:
        for i, doc in enumerate(documents):
            print(f"DOC {i+1}: Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}")
    print("---------------------------\n")
    return {"context": documents}
|
| 130 |
+
|
| 131 |
+
def generate_answer_node(state: AgentState) -> dict:
    """Simple-RAG path: answer from retrieved context and append source citations.

    Fixes vs. original: source filenames are derived with os.path.basename
    (portable across path separators), and a document with no 'source'
    metadata no longer renders as "A" (previously 'N/A'.split('/')[-1]).
    """
    print("---NODE: GENERATE_ANSWER (SIMPLE)---")
    query = state["rewritten_query"]
    context_docs = state["context"]
    context_str = format_docs_for_prompt(context_docs)

    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a helpful assistant. Answer the user's question based *only* on the retrieved context. "
            "If the context is empty or irrelevant, you *must* state that you do not have the information "
            "and recommend visiting the official Indian Institute of Information Technology, Design and Manufacturing, Jabalpur (IIITDM Jabalpur) website (https://www.iiitdmj.ac.in/) for more details."
        )),
        ("user", "Context:\n{context}\n\nQuestion:\n{query}")
    ])

    generation_chain = prompt | llm | StrOutputParser()
    answer = generation_chain.invoke({"context": context_str, "query": query})

    sources = []
    for i, doc in enumerate(context_docs):
        source_file = doc.metadata.get('source')
        # basename handles both '/' and OS-native separators; keep the
        # 'N/A' sentinel intact when the metadata is missing.
        source_name = os.path.basename(source_file) if source_file else 'N/A'
        page_num = doc.metadata.get('page', 'N/A')
        sources.append(f" {i+1}. {source_name} (Page: {page_num})")

    # Heuristic: the "visit the website" fallback answer should not carry
    # citations; the word 'website' in the answer is used as that signal.
    if sources and "website" not in answer:
        pretty_answer = answer + "\n--- \n**Sources:**\n" + "\n".join(sources)
    else:
        pretty_answer = answer

    return {"messages": [AIMessage(content=pretty_answer)]}
|
| 163 |
+
|
| 164 |
+
def decompose_query_node(state: AgentState) -> dict:
    """Split a comparative question into independent sub-queries via the LLM.

    Fix vs. original: `result['queries']` was an unguarded lookup — if the
    model returned malformed/unexpected JSON the whole graph run crashed.
    Now any parsing/shape failure degrades to a single sub-query (the
    original question), so the comparative path still produces an answer.
    """
    print("---NODE: DECOMPOSE_QUERY---")
    query = state["rewritten_query"]
    prompt = ChatPromptTemplate.from_template(
        """You are a query decomposition assistant...
Query: {query}
Respond with a JSON object..."""
    )
    decomposition_chain = prompt | classification_llm | JsonOutputParser()
    try:
        result = decomposition_chain.invoke({"query": query})
        sub_queries = result.get("queries") or [query]
    except Exception as e:
        # Best-effort fallback: retrieve on the original query instead of crashing.
        print(f"--- Decomposition failed ({e}); falling back to original query ---")
        sub_queries = [query]
    print(f"--- Sub-queries: {sub_queries} ---")
    return {"sub_queries": sub_queries}
|
| 177 |
+
|
| 178 |
+
def retrieve_multi_docs_node(state: AgentState) -> dict:
    """Comparative path: retrieve for each sub-query and de-duplicate by content."""
    print("---NODE: RETRIEVE_DOCS (MULTI)---")
    seen: dict = {}
    for sub_query in state["sub_queries"]:
        for doc in retriever.invoke(sub_query):
            # First occurrence of each page_content wins (insertion order kept).
            seen.setdefault(doc.page_content, doc)
    unique_docs = list(seen.values())

    print("\n--- RETRIEVED CONTEXT (MULTI) ---")
    if not unique_docs:
        print("!!! No context retrieved. !!!")
    else:
        for i, doc in enumerate(unique_docs):
            print(f"DOC {i+1}: Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}")
    print("---------------------------\n")
    return {"context": unique_docs}
|
| 194 |
+
|
| 195 |
+
def generate_synthesized_answer_node(state: AgentState) -> dict:
    """Comparative path: synthesize one answer from multi-query context + cite sources.

    Fixes vs. original: source filenames are derived with os.path.basename
    (portable across path separators), and a document with no 'source'
    metadata no longer renders as "A" (previously 'N/A'.split('/')[-1]).
    """
    print("---NODE: GENERATE_ANSWER (SYNTHESIZED)---")
    query = state["rewritten_query"]
    context_docs = state["context"]
    context_str = format_docs_for_prompt(context_docs)

    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a helpful assistant. Your task is to answer a comparative question based on the provided context. "
            "Synthesize the information from the context to form a comprehensive answer. "
            "If the context is insufficient, you *must* state that you do not have the information "
            "and recommend visiting the official Indian Institute of Information Technology, Design and Manufacturing, Jabalpur (IIITDM Jabalpur) website (https://www.iiitdmj.ac.in/) for more details."
        )),
        ("user", (
            "Here is the context I've gathered:\n{context}\n\n"
            "Now, please answer this original question:\n{query}"
        ))
    ])

    generation_chain = prompt | llm | StrOutputParser()
    answer = generation_chain.invoke({"context": context_str, "query": query})

    sources = []
    for i, doc in enumerate(context_docs):
        source_file = doc.metadata.get('source')
        # basename handles both '/' and OS-native separators; keep the
        # 'N/A' sentinel intact when the metadata is missing.
        source_name = os.path.basename(source_file) if source_file else 'N/A'
        page_num = doc.metadata.get('page', 'N/A')
        sources.append(f" {i+1}. {source_name} (Page: {page_num})")

    # Heuristic: the "visit the website" fallback answer should not carry
    # citations; the word 'website' in the answer is used as that signal.
    if sources and "website" not in answer:
        pretty_answer = answer + "\n--- \n**Sources:**\n" + "\n".join(sources)
    else:
        pretty_answer = answer

    return {"messages": [AIMessage(content=pretty_answer)]}
|
| 231 |
+
|
| 232 |
+
def router(state: AgentState) -> Literal["conversational", "simple_rag", "comparative_rag"]:
    """Conditional-edge selector: forward the decision made by classify_query_node."""
    destination = state["query_type"]
    print(f"--- ROUTING TO: {destination} ---")
    return destination
|
| 235 |
+
|
| 236 |
+
# In-memory checkpointer: per-thread state survives across turns of one process
# (lost on restart — swap for a persistent saver in production).
checkpointer = MemorySaver()
|
| 237 |
+
|
| 238 |
+
def build_graph():
    """Assemble and compile the routed RAG workflow with memory checkpointing.

    Topology: inject_system_prompt -> rewrite_query -> classify_query, which
    fans out to one of three terminating paths (chat / simple RAG /
    decompose+multi-retrieve+synthesize).
    """
    workflow = StateGraph(AgentState)

    # Registration order is irrelevant to execution; edges define the flow.
    node_registry = {
        "inject_system_prompt": inject_system_prompt,
        "rewrite_query": rewrite_query_node,
        "classify_query": classify_query_node,
        "handle_chat": handle_chat_node,
        "retrieve_docs": retrieve_docs_node,
        "generate_answer": generate_answer_node,
        "decompose_query": decompose_query_node,
        "retrieve_multi_docs": retrieve_multi_docs_node,
        "generate_synthesized_answer": generate_synthesized_answer_node,
    }
    for node_name, node_fn in node_registry.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("inject_system_prompt")
    workflow.add_edge("inject_system_prompt", "rewrite_query")
    workflow.add_edge("rewrite_query", "classify_query")
    workflow.add_conditional_edges(
        "classify_query",
        router,
        {
            "conversational": "handle_chat",
            "simple_rag": "retrieve_docs",
            "comparative_rag": "decompose_query",
        },
    )

    # Path A: pure chat.
    workflow.add_edge("handle_chat", END)
    # Path B: simple retrieval + answer.
    workflow.add_edge("retrieve_docs", "generate_answer")
    workflow.add_edge("generate_answer", END)
    # Path C: decompose -> multi-retrieve -> synthesize.
    workflow.add_edge("decompose_query", "retrieve_multi_docs")
    workflow.add_edge("retrieve_multi_docs", "generate_synthesized_answer")
    workflow.add_edge("generate_synthesized_answer", END)

    return workflow.compile(checkpointer=checkpointer)
|
| 272 |
+
|
| 273 |
+
# Module-level singleton imported by app.py.
chatbot = build_graph()
|
| 274 |
+
|
| 275 |
+
if __name__ == "__main__":
    # Quick CLI smoke test: run one question through the compiled graph and
    # pretty-print the latest message after each state update.
    config = {"configurable": {"thread_id": "test-direct-run-1"}}
    print("\n--- Testing Direct Run ---")
    inputs = {"messages": [HumanMessage(content="What is the name of director?")]}
    for event in chatbot.stream(inputs, config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()
|
app.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from agent import chatbot, classification_llm
|
| 3 |
+
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
|
| 4 |
+
import uuid
|
| 5 |
+
import asyncio
|
| 6 |
+
|
| 7 |
+
def generate_thread_id():
    """Return a fresh UUID4 used as a LangGraph conversation thread id."""
    return uuid.uuid4()
|
| 10 |
+
|
| 11 |
+
def reset_chat():
    """Start a fresh conversation: new thread id, registered, empty visible history."""
    new_id = uuid.uuid4()
    st.session_state['thread_id'] = new_id
    add_thread(new_id)
    st.session_state['message_history'] = []
|
| 16 |
+
|
| 17 |
+
def add_thread(thread_id):
    """Register a thread id once and assign a numbered placeholder title."""
    if thread_id in st.session_state['chat_threads']:
        return  # already registered — keep its existing title
    st.session_state['chat_threads'].append(thread_id)
    st.session_state['thread_titles'][thread_id] = f"New Chat {len(st.session_state['chat_threads'])}"
|
| 21 |
+
|
| 22 |
+
def load_conversation(thread_id):
    """Fetch the checkpointed messages for a thread; empty list on any failure."""
    try:
        state = chatbot.get_state(config={'configurable': {'thread_id': thread_id}})
        raw = state.values.get('messages', []) if state else []
        # Keep only proper message objects (defensive against stray values).
        return [m for m in raw if isinstance(m, BaseMessage)]
    except Exception as e:
        print(f"Error loading conversation for thread {thread_id}: {e}")
        return []
|
| 30 |
+
|
| 31 |
+
def generate_title(query):
    """Ask the non-streaming LLM for a short sidebar title; fall back to 'Chat'."""
    print("--- Generating Title ---")
    try:
        prompt = f"Summarize this query into a very short title (max 5 words): {query}"
        response = classification_llm.invoke(prompt)
        title = response.content.strip().strip('"')
        return title or "Chat"  # empty LLM output also falls back
    except Exception as e:
        print(f"Error generating title: {e}")
        return "Chat"
|
| 41 |
+
|
| 42 |
+
# --- Session-state bootstrap (runs on every rerun; guards keep it idempotent) ---
if 'message_history' not in st.session_state:
    st.session_state['message_history'] = []
if 'thread_id' not in st.session_state:
    st.session_state['thread_id'] = generate_thread_id()
if 'chat_threads' not in st.session_state:
    st.session_state['chat_threads'] = []
if 'thread_titles' not in st.session_state:
    st.session_state['thread_titles'] = {}
add_thread(st.session_state['thread_id'])

# --- Sidebar: new-chat button plus one button per saved conversation ---
st.sidebar.title("IIITDMJ Chatbot")
if st.sidebar.button("➕ New Chat"):
    reset_chat()
    st.rerun()
st.sidebar.header("My Conversations")
for thread_id in st.session_state['chat_threads'][::-1]:  # newest first
    title = st.session_state['thread_titles'].get(thread_id, "Untitled Chat")
    if st.sidebar.button(title, key=f"thread_{thread_id}", use_container_width=True):
        st.session_state['thread_id'] = thread_id
        # Rebuild the display history from the checkpointed graph state,
        # dropping the system prompt which is not meant to be shown.
        restored = []
        for msg in load_conversation(thread_id):
            if isinstance(msg, SystemMessage):
                continue
            role = 'user' if isinstance(msg, HumanMessage) else 'assistant'
            restored.append({'role': role, 'content': msg.content})
        st.session_state['message_history'] = restored
        st.rerun()

# --- Main pane: title, caption, and previously exchanged messages ---
st.title("IIITDMJ College Assistant")
st.caption("This bot uses a local vector store and LangGraph to answer your questions.")

for message in st.session_state['message_history']:
    with st.chat_message(message['role']):
        if message['role'] == 'assistant':
            # Assistant replies are rendered in the same styled div used
            # during streaming, for visual consistency.
            st.markdown(f"<div style='font-size: 15px;'>{message['content']}</div>", unsafe_allow_html=True)
        else:
            st.markdown(message['content'])
|
| 75 |
+
|
| 76 |
+
user_input = st.chat_input("Ask about IIITDMJ...")

if user_input:

    run_config = {'configurable': {'thread_id': st.session_state['thread_id']}}

    # Echo the user's message immediately.
    st.session_state['message_history'].append({'role': 'user', 'content': user_input})
    with st.chat_message('user'):
        st.markdown(user_input)

    with st.chat_message('assistant'):
        placeholder = st.empty()
        ai_message_content = ""

        try:
            print(f"\n--- Streaming response for Thread ID: {st.session_state['thread_id']} ---")

            async def consume_agent_events(stream_placeholder):
                """Drive the agent's event stream.

                Returns (streamed_text, final_node_output, final_node_name);
                the last two support a fallback when no tokens were streamed.
                """
                streamed_text = ""
                final_node_output = None
                final_node_name = ""
                answer_nodes = ("generate_answer", "generate_synthesized_answer", "handle_chat")

                async for event in chatbot.astream_events(
                    {'messages': [HumanMessage(content=user_input)]},
                    config=run_config,
                    version="v1"
                ):
                    kind = event["event"]
                    name = event["name"]

                    # Token-level chunks from the answer-producing nodes only.
                    if kind == "on_chat_model_stream" and name in answer_nodes:
                        chunk_content = event["data"]["chunk"].content
                        if chunk_content:
                            streamed_text += chunk_content
                            stream_placeholder.markdown(f"<div style='font-size: 15px;'>{streamed_text}▌</div>", unsafe_allow_html=True)

                    # Capture the answer node's final output as a fallback in
                    # case no token stream was observed.
                    if kind == "on_chain_end" and name in answer_nodes:
                        if "output" in event.get("data", {}) and isinstance(event["data"]["output"], dict):
                            final_node_output = event["data"]["output"]
                            final_node_name = name
                            print(f"--- Captured final output from node: {name} ---")

                return streamed_text, final_node_output, final_node_name

            streamed_content, final_output, final_name = asyncio.run(consume_agent_events(placeholder))

            if not streamed_content and final_output:
                # Fallback: the stream produced nothing, use the node's output.
                print(f"--- Using fallback: No stream content captured. Using final output from {final_name}. ---")
                if "messages" in final_output and final_output["messages"]:
                    ai_message_content = final_output["messages"][-1].content
                    placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
                else:
                    print(f"--- Fallback failed: Final output from {final_name} had unexpected format: {final_output} ---")
                    ai_message_content = "Sorry, I couldn't generate a response (fallback error)."
                    placeholder.markdown(ai_message_content)
            elif streamed_content:
                # Normal case: finalize the streamed text (drop the cursor glyph).
                ai_message_content = streamed_content
                placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
            else:
                print("--- Fallback failed: No stream content and no final output captured. ---")
                ai_message_content = "Sorry, I couldn't generate a response (capture error)."
                placeholder.markdown(ai_message_content)

        except Exception as e:
            st.error(f"An error occurred: {e}")
            print(f"ERROR DURING STREAM/FALLBACK: {e}")
            ai_message_content = "Sorry, I encountered an error during execution."
            placeholder.markdown(ai_message_content)

    if not ai_message_content:
        ai_message_content = "Sorry, I couldn't generate a response."

    st.session_state['message_history'].append({'role': 'assistant', 'content': ai_message_content})

    # Replace the placeholder "New Chat N" title after the first exchange.
    current_id = st.session_state['thread_id']
    current_title = st.session_state['thread_titles'].get(current_id, "New Chat")
    if current_title.startswith("New Chat") and len(st.session_state['message_history']) <= 2:
        summarized_title = generate_title(user_input)
        st.session_state['thread_titles'][current_id] = summarized_title
        # NOTE(review): the dump does not preserve indentation — assuming the
        # rerun belongs to the title-update branch; confirm against the repo.
        st.rerun()
|
check.ipynb
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 20,
|
| 6 |
+
"id": "081405cc",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"data": {
|
| 11 |
+
"text/plain": [
|
| 12 |
+
"True"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
"execution_count": 20,
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"output_type": "execute_result"
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"source": [
|
| 21 |
+
"import os\n",
|
| 22 |
+
"from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader\n",
|
| 23 |
+
"from langchain_community.vectorstores import FAISS\n",
|
| 24 |
+
"from langchain_community.embeddings import HuggingFaceEmbeddings\n",
|
| 25 |
+
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
|
| 26 |
+
"from dotenv import load_dotenv\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"load_dotenv()"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": 21,
|
| 34 |
+
"id": "3c40840f",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"MODEL_NAME = \"sentence-transformers/all-MiniLM-L12-v2\"\n",
|
| 39 |
+
"DATA_PATH=\"data/\""
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 22,
|
| 45 |
+
"id": "90fc0a47",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [
|
| 48 |
+
{
|
| 49 |
+
"name": "stdout",
|
| 50 |
+
"output_type": "stream",
|
| 51 |
+
"text": [
|
| 52 |
+
"Loading documents from data/...\n",
|
| 53 |
+
"Loaded 2087 PDF document(s).\n",
|
| 54 |
+
"Split into 25938 chunks.\n",
|
| 55 |
+
"Creating and saving FAISS vector store...\n"
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
],
|
| 59 |
+
"source": [
|
| 60 |
+
"embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"print(f\"Loading documents from {DATA_PATH}...\")\n",
|
| 63 |
+
"loader = DirectoryLoader(\n",
|
| 64 |
+
" DATA_PATH,\n",
|
| 65 |
+
" glob='*.pdf', \n",
|
| 66 |
+
" loader_cls=PyPDFLoader \n",
|
| 67 |
+
")\n",
|
| 68 |
+
"documents = loader.load()\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"if not documents:\n",
|
| 71 |
+
" print(\"No PDF documents found. Make sure your PDFs are in the /data folder.\")\n",
|
| 72 |
+
" exit()\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"print(f\"Loaded {len(documents)} PDF document(s).\")\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# 3. Split Documents\n",
|
| 77 |
+
"text_splitter = RecursiveCharacterTextSplitter(\n",
|
| 78 |
+
" chunk_size=300, \n",
|
| 79 |
+
" chunk_overlap=200,\n",
|
| 80 |
+
" separators=[\"\\n\\n\", \"\\n\", \".\", \"!\", \"?\", \" \", \"\"]\n",
|
| 81 |
+
" )\n",
|
| 82 |
+
"docs = text_splitter.split_documents(documents)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"print(f\"Split into {len(docs)} chunks.\")\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"# 4. Create and Save FAISS Vector Store\n",
|
| 87 |
+
"print(\"Creating and saving FAISS vector store...\")\n",
|
| 88 |
+
"db = FAISS.from_documents(docs, embeddings)"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": null,
|
| 94 |
+
"id": "9ca0ee2b",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [
|
| 97 |
+
{
|
| 98 |
+
"name": "stdout",
|
| 99 |
+
"output_type": "stream",
|
| 100 |
+
"text": [
|
| 101 |
+
"Loading embedding model: sentence-transformers/all-MiniLM-L12-v2...\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"✅ Retriever is ready.\n",
|
| 104 |
+
" Enter your query to test. Type 'exit' to quit.\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"--- Retrieving docs for: 'who is director' ---\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"--- Document 1 ---\n",
|
| 109 |
+
"Source: data/iiitdmj_crawl_data_1.pdf\n",
|
| 110 |
+
"Page: 133\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"Content:\n",
|
| 113 |
+
"director@iiitdmj.ac.in\n",
|
| 114 |
+
"2.\n",
|
| 115 |
+
"Deputy Director\n",
|
| 116 |
+
"To be nominated on appointment\n",
|
| 117 |
+
"3.\n",
|
| 118 |
+
"Deans (Ex-officio)\n",
|
| 119 |
+
"1. Dr. Mukesh Kumar Roy\n",
|
| 120 |
+
"Faculty-in-Charge (Student Affairs)\n",
|
| 121 |
+
"mkroy@iiitdmj.ac.in\n",
|
| 122 |
+
"2. Prof. V. K. Gupta\n",
|
| 123 |
+
"Professor In-charge (Academic)\n",
|
| 124 |
+
"dean.acad@iiitdmj.ac.in\n",
|
| 125 |
+
"3. Prof. Pritee Khanna\n",
|
| 126 |
+
"--------------------\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"--- Document 2 ---\n",
|
| 129 |
+
"Source: data/IIITDM Jabalpur.pdf\n",
|
| 130 |
+
"Page: 2\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"Content:\n",
|
| 133 |
+
" The Deputy Director (to be nominated on appointment) \n",
|
| 134 |
+
" The Deans \n",
|
| 135 |
+
" The Heads of various disciplines and \n",
|
| 136 |
+
" The Registrar \n",
|
| 137 |
+
" \n",
|
| 138 |
+
" \n",
|
| 139 |
+
" \n",
|
| 140 |
+
" \n",
|
| 141 |
+
"Building And Works Committee \n",
|
| 142 |
+
"S. No. Name Designation \n",
|
| 143 |
+
"1. Prof. Bhartendu Kumar Singh \n",
|
| 144 |
+
"Director \n",
|
| 145 |
+
"PDPM-IIITDM Jabalpur \n",
|
| 146 |
+
"director@iiitdmj.ac.in\n",
|
| 147 |
+
"--------------------\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"--- Document 3 ---\n",
|
| 150 |
+
"Source: data/iiitdmj_crawl_data_1.pdf\n",
|
| 151 |
+
"Page: 133\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"Content:\n",
|
| 154 |
+
"S. No.\n",
|
| 155 |
+
"Name\n",
|
| 156 |
+
"Address\n",
|
| 157 |
+
"1.\n",
|
| 158 |
+
"Director as Chairperson (Ex-officio)\n",
|
| 159 |
+
"Prof. Bhartendu K Singh (Director)\n",
|
| 160 |
+
"director@iiitdmj.ac.in\n",
|
| 161 |
+
"2.\n",
|
| 162 |
+
"Deputy Director\n",
|
| 163 |
+
"To be nominated on appointment\n",
|
| 164 |
+
"3.\n",
|
| 165 |
+
"Deans (Ex-officio)\n",
|
| 166 |
+
"1. Dr. Mukesh Kumar Roy\n",
|
| 167 |
+
"Faculty-in-Charge (Student Affairs)\n",
|
| 168 |
+
"mkroy@iiitdmj.ac.in\n",
|
| 169 |
+
"2. Prof. V. K. Gupta\n",
|
| 170 |
+
"--------------------\n"
|
| 171 |
+
]
|
| 172 |
+
}
|
| 173 |
+
],
|
| 174 |
+
"source": [
|
| 175 |
+
"import sys\n",
|
| 176 |
+
"from langchain_community.vectorstores import FAISS\n",
|
| 177 |
+
"from langchain_community.embeddings import HuggingFaceEmbeddings\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"def check_retriever():\n",
|
| 181 |
+
" \"\"\"\n",
|
| 182 |
+
" A standalone script to test the FAISS retriever.\n",
|
| 183 |
+
" \"\"\"\n",
|
| 184 |
+
" \n",
|
| 185 |
+
" # 1. Load the Embedding Model\n",
|
| 186 |
+
" print(f\"Loading embedding model: {MODEL_NAME}...\")\n",
|
| 187 |
+
" try:\n",
|
| 188 |
+
" # This line might show a deprecation warning, which is OK.\n",
|
| 189 |
+
" # It's the same one your agent.py is using.\n",
|
| 190 |
+
" embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)\n",
|
| 191 |
+
" except Exception as e:\n",
|
| 192 |
+
" print(f\"Error loading embeddings: {e}\")\n",
|
| 193 |
+
" print(\"Make sure 'sentence-transformers' is installed: pip install sentence-transformers\")\n",
|
| 194 |
+
" return\n",
|
| 195 |
+
"\n",
|
| 196 |
+
" # # 2. Load the FAISS Vector Store\n",
|
| 197 |
+
" # print(f\"Loading FAISS index from: {DB_FAISS_PATH}...\")\n",
|
| 198 |
+
" # try:\n",
|
| 199 |
+
" # db = FAISS.load_local(\n",
|
| 200 |
+
" # DB_FAISS_PATH, \n",
|
| 201 |
+
" # embeddings, \n",
|
| 202 |
+
" # allow_dangerous_deserialization=True # This is required\n",
|
| 203 |
+
" # )\n",
|
| 204 |
+
" # except Exception as e:\n",
|
| 205 |
+
" # print(f\"Error loading FAISS index: {e}\")\n",
|
| 206 |
+
" # print(\"Be sure you have run 'python ingest.py' successfully first.\")\n",
|
| 207 |
+
" # return\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" retriever = db.as_retriever(search_kwargs={'k': 3})\n",
|
| 210 |
+
" \n",
|
| 211 |
+
" print(\"\\n✅ Retriever is ready.\")\n",
|
| 212 |
+
" print(\" Enter your query to test. Type 'exit' to quit.\")\n",
|
| 213 |
+
" \n",
|
| 214 |
+
" while True:\n",
|
| 215 |
+
" try:\n",
|
| 216 |
+
" query = input(\"\\nQuery> \")\n",
|
| 217 |
+
" if query.lower() == 'exit':\n",
|
| 218 |
+
" break\n",
|
| 219 |
+
" if not query:\n",
|
| 220 |
+
" continue\n",
|
| 221 |
+
" \n",
|
| 222 |
+
" print(f\"\\n--- Retrieving docs for: '{query}' ---\")\n",
|
| 223 |
+
" \n",
|
| 224 |
+
" documents = retriever.invoke(query)\n",
|
| 225 |
+
" \n",
|
| 226 |
+
" if not documents:\n",
|
| 227 |
+
" print(\"\\n!!! No documents found. !!!\")\n",
|
| 228 |
+
" else:\n",
|
| 229 |
+
" for i, doc in enumerate(documents):\n",
|
| 230 |
+
" print(f\"\\n--- Document {i+1} ---\")\n",
|
| 231 |
+
" print(f\"Source: {doc.metadata.get('source', 'N/A')}\")\n",
|
| 232 |
+
" print(f\"Page: {doc.metadata.get('page', 'N/A')}\")\n",
|
| 233 |
+
" print(\"\\nContent:\")\n",
|
| 234 |
+
" print(doc.page_content)\n",
|
| 235 |
+
" print(\"-\" * 20)\n",
|
| 236 |
+
" \n",
|
| 237 |
+
" except Exception as e:\n",
|
| 238 |
+
" print(f\"An error occurred: {e}\")\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"if __name__ == \"__main__\":\n",
|
| 241 |
+
" check_retriever()\n"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"cell_type": "code",
|
| 246 |
+
"execution_count": 24,
|
| 247 |
+
"id": "45430224",
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"outputs": [],
|
| 250 |
+
"source": [
|
| 251 |
+
"DB_FAISS_PATH = \"vectorstore/faiss_index2\"\n"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"cell_type": "code",
|
| 256 |
+
"execution_count": 25,
|
| 257 |
+
"id": "9488f2a3",
|
| 258 |
+
"metadata": {},
|
| 259 |
+
"outputs": [
|
| 260 |
+
{
|
| 261 |
+
"name": "stdout",
|
| 262 |
+
"output_type": "stream",
|
| 263 |
+
"text": [
|
| 264 |
+
"Successfully created and saved FAISS index to vectorstore/faiss_index2\n"
|
| 265 |
+
]
|
| 266 |
+
}
|
| 267 |
+
],
|
| 268 |
+
"source": [
|
| 269 |
+
"db = FAISS.from_documents(docs, embeddings)\n",
|
| 270 |
+
"db.save_local(DB_FAISS_PATH)\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"print(f\"Successfully created and saved FAISS index to {DB_FAISS_PATH}\")"
|
| 273 |
+
]
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "code",
|
| 277 |
+
"execution_count": null,
|
| 278 |
+
"id": "bef0e8c2",
|
| 279 |
+
"metadata": {},
|
| 280 |
+
"outputs": [],
|
| 281 |
+
"source": []
|
| 282 |
+
}
|
| 283 |
+
],
|
| 284 |
+
"metadata": {
|
| 285 |
+
"kernelspec": {
|
| 286 |
+
"display_name": "venv",
|
| 287 |
+
"language": "python",
|
| 288 |
+
"name": "python3"
|
| 289 |
+
},
|
| 290 |
+
"language_info": {
|
| 291 |
+
"codemirror_mode": {
|
| 292 |
+
"name": "ipython",
|
| 293 |
+
"version": 3
|
| 294 |
+
},
|
| 295 |
+
"file_extension": ".py",
|
| 296 |
+
"mimetype": "text/x-python",
|
| 297 |
+
"name": "python",
|
| 298 |
+
"nbconvert_exporter": "python",
|
| 299 |
+
"pygments_lexer": "ipython3",
|
| 300 |
+
"version": "3.13.7"
|
| 301 |
+
}
|
| 302 |
+
},
|
| 303 |
+
"nbformat": 4,
|
| 304 |
+
"nbformat_minor": 5
|
| 305 |
+
}
|
ingest.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

from dotenv import load_dotenv
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# Where the source PDFs live and where the FAISS index is persisted.
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/faiss_index"
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"  # Model for embeddings


def main() -> None:
    """Ingest every PDF under DATA_PATH into a FAISS index at DB_FAISS_PATH.

    Pipeline: load PDFs -> split into overlapping chunks -> embed with a
    sentence-transformers model -> build and persist a FAISS vector store.
    Exits with a non-zero status when no PDFs are found.
    """
    # Load documents FIRST so an empty data directory fails fast, before
    # paying the (slow) cost of downloading/initialising the embedding model.
    print(f"Loading documents from {DATA_PATH}...")
    loader = DirectoryLoader(
        DATA_PATH,
        glob='*.pdf',
        loader_cls=PyPDFLoader,
    )
    documents = loader.load()

    if not documents:
        print("No PDF documents found. Make sure your PDFs are in the /data folder.")
        # Was a bare exit(): that name comes from the `site` module (not
        # guaranteed in all run contexts) and exited with status 0 on failure.
        sys.exit(1)

    print(f"Loaded {len(documents)} PDF document(s).")

    # NOTE(review): chunk_overlap=200 against chunk_size=300 is a very high
    # (~66%) overlap — kept as-is to avoid changing the index layout; confirm
    # it is intentional.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=200,
        separators=["\n\n", "\n", ".", "!", "?", " ", ""],
    )
    docs = text_splitter.split_documents(documents)
    print(f"Split into {len(docs)} chunks.")

    # Embedding model is only loaded once we know there is work to do.
    embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)

    print("Creating and saving FAISS vector store...")
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(DB_FAISS_PATH)
    print(f"Successfully created and saved FAISS index to {DB_FAISS_PATH}")


# Guard the entry point so importing this module no longer triggers a full
# (network-touching, disk-writing) ingestion as an import side effect.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
langchain
|
| 3 |
+
langchain-community
|
| 4 |
+
langgraph
|
| 5 |
+
langchain-google-genai
|
| 6 |
+
langchain-huggingface
|
| 7 |
+
faiss-cpu
|
| 8 |
+
sentence-transformers
|
| 9 |
+
pypdf
|
| 10 |
+
python-dotenv
|
| 11 |
+
langchain-text-splitters
|
| 12 |
+
pydantic
|
| 13 |
+
tiktoken
|