# LangGraph-based QA pipeline: PDF retrieval first, then LLM fallback, then web search.
| from langgraph.graph import StateGraph, END | |
| from langchain.chains import RetrievalQA | |
| from typing import TypedDict, Optional | |
| from tools import llm, load_vectorstore, search_tool | |
# Load your vectorstore
# Built once at import time; shared by pdf_qa_node's retriever below.
vectorstore = load_vectorstore()
# --- TypedDict to define graph state schema ---
class GraphState(TypedDict):
    """Shared state passed between the LangGraph nodes.

    Each fallback stage fills in its own answer slot; slots stay None
    (or absent) until the corresponding node has run.
    """
    question: str                # the user's original query (set by the caller)
    pdf_answer: Optional[str]    # answer from the PDF RetrievalQA chain, if any
    llm_answer: Optional[str]    # answer from the raw-LLM fallback, if any
    web_answer: Optional[str]    # answer from the web-search fallback, if any
    # NOTE(review): the respond_* terminal nodes return an "answer" key that is
    # not declared here — confirm the LangGraph version in use tolerates
    # undeclared state keys.
# --- LangGraph Node Functions ---
def pdf_qa_node(state: "GraphState") -> "GraphState":
    """Answer the question from the vectorstore via a RetrievalQA chain.

    Stores the chain's answer under ``pdf_answer``; the rest of the state
    is passed through untouched.
    """
    question = state["question"]
    # Chain is rebuilt per invocation; llm and vectorstore are module-level.
    chain = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
    answer = chain.run(question)
    return {**state, "pdf_answer": answer}
def check_pdf_relevance(state: "GraphState") -> str:
    """Route after the PDF QA step.

    Returns:
        "respond_pdf" when the retrieval answer looks usable, or
        "llm_fallback" when it looks like a non-answer (a refusal phrase
        or a suspiciously short reply, < 20 chars).
    """
    # pdf_answer is Optional[str]: an explicit None value would make the
    # original `.get("pdf_answer", "").lower()` raise AttributeError, since
    # .get's default only covers a *missing* key. `or ""` covers both cases.
    ans = (state.get("pdf_answer") or "").lower()
    # Phrases the QA chain emits when retrieval found nothing relevant.
    refusal_markers = (
        "i don't know",
        "i don't have information",
        "no relevant",
        "not available",
    )
    if any(marker in ans for marker in refusal_markers) or len(ans.strip()) < 20:
        return "llm_fallback"
    return "respond_pdf"
def llm_fallback_node(state: "GraphState") -> "GraphState":
    """Ask the bare LLM directly when retrieval produced nothing useful.

    Stores the model's reply under ``llm_answer``; all other state keys
    are passed through unchanged.
    """
    question = state["question"]
    prompt = f"""You are a helpful AI assistant. The user asked a question, and no relevant documents were found.
Try your best to answer this:
Question: {question}
Answer:"""
    response = llm.invoke(prompt)
    return {**state, "llm_answer": response.content}
def check_llm_confidence(state: "GraphState") -> str:
    """Route after the LLM fallback step.

    Returns:
        "web_search" when the LLM's answer is missing, empty, or expresses
        uncertainty; "respond_llm" otherwise.
    """
    # llm_answer is Optional[str]: an explicit None value would make the
    # original `.get("llm_answer", "").lower()` raise AttributeError — .get's
    # default only applies to a *missing* key. `or ""` covers both cases.
    ans = (state.get("llm_answer") or "").lower()
    uncertainty_markers = ("i don't know", "not sure", "no idea")
    # An empty answer is as useless as an uncertain one — fall through to web
    # search rather than surfacing a blank/None reply to the user.
    if not ans.strip() or any(marker in ans for marker in uncertainty_markers):
        return "web_search"
    return "respond_llm"
def web_search_node(state: "GraphState") -> "GraphState":
    """Last-resort node: answer the question with the external web search tool.

    Stores the tool's output under ``web_answer``; all other state keys are
    passed through unchanged.
    """
    search_result = search_tool(state["question"])
    return {**state, "web_answer": search_result}
def respond_pdf(state: "GraphState") -> dict:
    """Terminal node: package the retrieval-based answer for the caller."""
    # NOTE(review): "π" looks like a mojibake'd emoji from the original
    # source; reproduced byte-for-byte to keep output identical.
    print("π Responding from PDF")
    answer = state["pdf_answer"]
    return {"answer": answer}
def respond_llm(state: "GraphState") -> dict:
    """Terminal node: package the direct-LLM answer for the caller."""
    # NOTE(review): "π€" looks like a mojibake'd emoji from the original
    # source; reproduced byte-for-byte to keep output identical.
    print("π€ Responding from LLM")
    answer = state["llm_answer"]
    return {"answer": answer}
def respond_web(state: "GraphState") -> dict:
    """Terminal node: package the web-search answer for the caller."""
    # NOTE(review): "π" looks like a mojibake'd emoji from the original
    # source; reproduced byte-for-byte to keep output identical.
    print("π Responding from Web Search")
    answer = state["web_answer"]
    return {"answer": answer}
# --- Graph Creation Function ---
def create_graph():
    """Wire the nodes into a compiled LangGraph state machine.

    Flow: pdf_qa -> (respond_pdf | llm_fallback)
          llm_fallback -> (respond_llm | web_search -> respond_web)
          every respond_* node -> END
    """
    workflow = StateGraph(GraphState)  # pass the state schema

    # Register every node once; keys match the names the routers return.
    node_table = {
        "pdf_qa": pdf_qa_node,
        "llm_fallback": llm_fallback_node,
        "web_search": web_search_node,
        "respond_pdf": respond_pdf,
        "respond_llm": respond_llm,
        "respond_web": respond_web,
    }
    for node_name, node_fn in node_table.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("pdf_qa")

    # Router return values map straight onto node names.
    workflow.add_conditional_edges(
        "pdf_qa",
        check_pdf_relevance,
        {"respond_pdf": "respond_pdf", "llm_fallback": "llm_fallback"},
    )
    workflow.add_conditional_edges(
        "llm_fallback",
        check_llm_confidence,
        {"respond_llm": "respond_llm", "web_search": "web_search"},
    )
    workflow.add_edge("web_search", "respond_web")

    # All terminal responders end the run.
    for terminal in ("respond_pdf", "respond_llm", "respond_web"):
        workflow.add_edge(terminal, END)

    return workflow.compile()