# NOTE(review): removed stray "Spaces: Sleeping" Hugging Face page artifacts — they are not valid Python.
# rag_agent_app/backend/agent.py
import os
from typing import Annotated, List, Literal, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langchain_tavily import TavilySearch
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field

# Import API keys from config
from config import GROQ_API_KEY, TAVILY_API_KEY
from vectorstore import get_retriever
| # --- Tools --- | |
| os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY | |
| tavily = TavilySearch(max_results=3, topic="general") | |
@tool
def web_search_tool(query: str) -> str:
    """Up-to-date web info via Tavily"""
    # FIX: the @tool decorator was missing even though `tool` is imported and
    # web_node calls this via `.invoke(...)` — plain functions have no
    # `.invoke`, so the decorator is required for the call sites to work.
    try:
        result = tavily.invoke({"query": query})
        # Tavily returns a dict with a 'results' list; flatten it into
        # readable "Title/Content/URL" text blocks for the answer LLM.
        if isinstance(result, dict) and 'results' in result:
            formatted_results = []
            for item in result['results']:
                title = item.get('title', 'No title')
                content = item.get('content', 'No content')
                url = item.get('url', '')
                formatted_results.append(f"Title: {title}\nContent: {content}\nURL: {url}")
            return "\n\n".join(formatted_results) if formatted_results else "No results found"
        else:
            return str(result)
    except Exception as e:
        # Sentinel prefix lets callers detect failure without raising.
        return f"WEB_ERROR::{e}"
@tool
def rag_search_tool(query: str) -> str:
    """Top-K chunks from KB (empty string if none)"""
    # FIX: the @tool decorator was missing even though `tool` is imported and
    # rag_node calls this via `.invoke(...)`, which plain functions lack.
    try:
        retriever_instance = get_retriever()
        docs = retriever_instance.invoke(query, k=5)  # Increased from 3 to 5
        return "\n\n".join(d.page_content for d in docs) if docs else ""
    except Exception as e:
        # Sentinel prefix lets rag_node detect failure without raising.
        return f"RAG_ERROR::{e}"
# --- Pydantic schemas for structured output ---
# NOTE: field descriptions feed the structured-output schema shown to the
# router LLM, so they double as prompt text.
class RouteDecision(BaseModel):
    # Which node should handle the query next.
    route: Literal["rag", "web", "answer", "end"]
    reply: str | None = Field(None, description="Filled only when route == 'end'")
# Judge verdict on whether retrieved RAG chunks suffice to answer the query.
class RagJudge(BaseModel):
    verdict: Literal["yes", "no"] = Field(..., description="Set to 'yes' if retrieved info is sufficient, 'no' otherwise.")
# --- LLM instances with structured output where needed ---
# ChatGroq reads its key from the environment; set it before instantiation.
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
# temperature=0 for deterministic routing/judging decisions.
router_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0).with_structured_output(RouteDecision)
judge_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0).with_structured_output(RagJudge)
# Higher temperature for the user-facing answer generation.
answer_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.7)
# --- Shared state type ---
class AgentState(TypedDict, total=False):
    """Shared graph state (total=False: every key is optional)."""
    messages: Annotated[List[BaseMessage], add_messages]  # chat history, merged via add_messages
    route: Literal["rag", "web", "answer", "end"]  # next hop chosen by the current node
    rag: str  # retrieved KB chunks ("" when none)
    web: str  # web search snippets ("" when failed/none)
    web_search_enabled: bool  # user toggle propagated from the run config
    # FIX: these tracing keys are written by router_node when it overrides an
    # LLM 'web' decision while web search is disabled; they were previously
    # written to the state but never declared here.
    initial_router_decision: Literal["rag", "web", "answer", "end"]
    router_override_reason: str
# --- Node 1: router (decision) ---
def router_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Decide which node handles the latest user message.

    Reads the most recent HumanMessage, asks router_llm for a structured
    RouteDecision, and enforces the per-invocation web-search toggle taken
    from config["configurable"]["web_search_enabled"].
    """
    print("\n--- Entering router_node ---")
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # The toggle lives in the RunnableConfig, not the graph state.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    print(f"Router received web search info : {web_search_enabled}")
    system_prompt = (
        "You are an intelligent routing agent designed to direct user queries to the most appropriate tool."
        "Your primary goal is to provide accurate and relevant information by selecting the best source."
        "Prioritize using the **internal knowledge base (RAG)** for factual information that is likely "
        "to be contained within pre-uploaded documents or for common, well-established facts."
    )
    if web_search_enabled:
        system_prompt += (
            "You **CAN** use web search for queries that require very current, real-time, or broad general knowledge "
            "that is unlikely to be in a specific, static knowledge base (e.g., today's news, live data, very recent events)."
            "\n\nChoose one of the following routes:"
            "\n- 'rag': For queries about specific entities, historical facts, product details, procedures, or any information that would typically be found in a curated document collection (e.g., 'What is X?', 'How does Y work?', 'Explain Z policy')."
            "\n- 'web': For queries about current events, live data, very recent news, or broad general knowledge that requires up-to-date internet access (e.g., 'Who won the election yesterday?', 'What is the weather in London?', 'Latest news on technology')."
        )
    else:
        # FIX: the 'answer' bullet was duplicated in this branch; the shared
        # section appended below already lists it for both branches.
        system_prompt += (
            "**Web search is currently DISABLED.** You **MUST NOT** choose the 'web' route."
            "If a query would normally require web search, you should attempt to answer it using RAG (if applicable) or directly from your general knowledge."
            "\n\nChoose one of the following routes:"
            "\n- 'rag': For queries about specific entities, historical facts, product details, procedures, or any information that would typically be found in a curated document collection, AND for queries that would normally go to web search but web search is disabled."
        )
    system_prompt += (
        "\n- 'answer': For very simple, direct questions you can answer without any external lookup (e.g., 'What is your name?')."
        "\n- 'end': For pure greetings or small-talk where no factual answer is expected (e.g., 'Hi', 'How are you?'). If choosing 'end', you MUST provide a 'reply'."
        "\n\nExample routing decisions:"
        "\n- User: 'What are the treatment of diabetes?' -> Route: 'rag' (Factual knowledge, likely in KB)."
        "\n- User: 'What is the capital of France?' -> Route: 'rag' (Common knowledge, can be in KB or answered directly if LLM knows)."
        "\n- User: 'Who won the NBA finals last night?' -> Route: 'web' (Current event, requires live data)."
        "\n- User: 'How do I submit an expense report?' -> Route: 'rag' (Internal procedure)."
        "\n- User: 'Tell me about quantum computing.' -> Route: 'rag' (Foundational knowledge can be in KB. If KB is sparse, judge will route to web if enabled)."
        "\n- User: 'Hello there!' -> Route: 'end', reply='Hello! How can I assist you today?'"
    )
    messages = [
        ("system", system_prompt),
        ("user", query)
    ]
    result: RouteDecision = router_llm.invoke(messages)
    initial_router_decision = result.route  # the LLM's raw decision, kept for tracing
    router_override_reason = None
    # Enforce the toggle: the LLM may still pick 'web' despite the prompt.
    if not web_search_enabled and result.route == "web":
        result.route = "rag"
        router_override_reason = "Web search disabled by user; redirected to RAG."
        print("Router decision overridden: changed from 'web' to 'rag' because web search is disabled.")
    print(f"Router final decision: {result.route}, Reply (if 'end'): {result.reply}")
    out = {
        "messages": state["messages"],
        "route": result.route,
        "web_search_enabled": web_search_enabled  # pass the flag along in the state
    }
    if router_override_reason:  # add override info for tracing
        out["initial_router_decision"] = initial_router_decision
        out["router_override_reason"] = router_override_reason
    if result.route == "end":
        # Small-talk short-circuit: reply directly without touring the graph.
        out["messages"] = [AIMessage(content=result.reply or "Hello!")]
    print("--- Exiting router_node ---")
    return out
# --- Node 2: RAG lookup ---
def rag_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Retrieve KB chunks for the latest user message and judge sufficiency.

    Routes onward to 'answer' when the chunks are judged sufficient (or when
    web search is disabled); otherwise routes to 'web'.
    """
    print("\n--- Entering rag_node ---")
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # The toggle lives in the RunnableConfig, not the graph state.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    # FIX: log line previously said "Router received ..." (copy-paste error).
    print(f"RAG node received web search info : {web_search_enabled}")
    print(f"RAG query: {query}")
    chunks = rag_search_tool.invoke(query)
    if chunks.startswith("RAG_ERROR::"):
        print(f"RAG Error: {chunks}. Checking web search enabled status.")
        # If RAG fails, and web search is enabled, try web. Otherwise, go to answer.
        next_route = "web" if web_search_enabled else "answer"
        # FIX: propagate web_search_enabled here too, like the other returns.
        return {**state, "rag": "", "route": next_route, "web_search_enabled": web_search_enabled}
    if not chunks:
        print("No RAG chunks retrieved. Skipping judge.")
        next_route = "web" if web_search_enabled else "answer"
        return {**state, "rag": "", "route": next_route, "web_search_enabled": web_search_enabled}
    judge_messages = [
        ("system", (
            "You are a judge evaluating if the retrieved information is sufficient and relevant "
            "to fully and accurately answer the user's question. "
            "Consider if the retrieved text directly addresses the question's core and provides enough detail."
            "If the information is incomplete, vague, or doesn't directly answer the question, it is NOT sufficient."
            "\n\nRespond with 'yes' for sufficient, 'no' for insufficient."
        )),
        ("user", f"Question: {query}\n\nRetrieved info: {chunks}\n\nIs this sufficient to answer the question?")
    ]
    result: RagJudge = judge_llm.invoke(judge_messages)
    is_sufficient = result.verdict == "yes"
    print(f"RAG Judge verdict: {result.verdict} (is_sufficient={is_sufficient})")
    print("--- Exiting rag_node ---")
    # Decide next route based on sufficiency AND web_search_enabled.
    if is_sufficient:
        next_route = "answer"
    else:
        # Insufficient chunks: fall through to web only when it is enabled.
        next_route = "web" if web_search_enabled else "answer"
        print(f"RAG not sufficient. Web search enabled: {web_search_enabled}. Next route: {next_route}")
    return {
        **state,
        "rag": chunks,
        "route": next_route,
        "web_search_enabled": web_search_enabled  # pass the flag along
    }
# --- Node 3: web search ---
def web_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Run a Tavily web search for the latest user message.

    Skips the search (with a placeholder message in state['web']) when the
    user has disabled web search in the run config; always routes to 'answer'.
    """
    print("\n--- Entering web_node ---")
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # Re-check the toggle: the graph may reach this node via a stale route.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    # FIX: log line previously said "Router received ..." (copy-paste error).
    print(f"Web node received web search info : {web_search_enabled}")
    if not web_search_enabled:
        print("Web search node entered but web search is disabled. Skipping actual search.")
        # answer_node recognizes this placeholder and excludes it from context.
        return {**state, "web": "Web search was disabled by the user.", "route": "answer"}
    print(f"Web search query: {query}")
    snippets = web_search_tool.invoke(query)
    if snippets.startswith("WEB_ERROR::"):
        print(f"Web Error: {snippets}. Proceeding to answer with limited info.")
        return {**state, "web": "", "route": "answer"}
    print(f"Web snippets retrieved: {snippets[:200]}...")
    print("--- Exiting web_node ---")
    return {**state, "web": snippets, "route": "answer"}
# --- Node 4: final answer ---
def answer_node(state: AgentState) -> AgentState:
    """Compose the final reply from whatever RAG/web context was gathered."""
    print("\n--- Entering answer_node ---")
    user_q = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    ctx_parts = []
    rag_text = state.get("rag")
    if rag_text:
        ctx_parts.append("Knowledge Base Information:\n" + rag_text)
    web_text = state.get("web")
    # 'web' may hold a "Web search was disabled..." placeholder instead of
    # real results; only genuine search output belongs in the context.
    if web_text and not web_text.startswith("Web search was disabled"):
        ctx_parts.append("Web Search Results:\n" + web_text)
    context = "\n\n".join(ctx_parts)
    if not context.strip():
        context = "No external context was available for this query. Try to answer based on general knowledge if possible."
    prompt = f"""Please answer the user's question using the provided context.
If the context is empty or irrelevant, try to answer based on your general knowledge.
Question: {user_q}
Context:
{context}
Provide a helpful, accurate, and concise response based on the available information."""
    print(f"Prompt sent to answer_llm: {prompt[:500]}...")
    ans = answer_llm.invoke(state["messages"] + [HumanMessage(content=prompt)]).content
    print(f"Final answer generated: {ans[:200]}...")
    print("--- Exiting answer_node ---")
    return {**state, "messages": [AIMessage(content=ans)]}
# --- Routing helpers ---
def from_router(st: AgentState) -> Literal["rag", "web", "answer", "end"]:
    """Conditional-edge selector: follow the route chosen by router_node."""
    chosen = st["route"]
    return chosen
def after_rag(st: AgentState) -> Literal["answer", "web"]:
    """Conditional-edge selector: follow the route chosen by rag_node."""
    chosen = st["route"]
    return chosen
def after_web(_) -> Literal["answer"]:
    """Web search always flows into the answer node; the state is ignored."""
    return "answer"
# --- Build graph ---
def build_agent():
    """Builds and compiles the LangGraph agent."""
    graph = StateGraph(AgentState)
    # Register the four worker nodes.
    for node_name, node_fn in (
        ("router", router_node),
        ("rag_lookup", rag_node),
        ("web_search", web_node),
        ("answer", answer_node),
    ):
        graph.add_node(node_name, node_fn)
    graph.set_entry_point("router")
    # Router fans out to every node (or straight to END for small-talk).
    graph.add_conditional_edges(
        "router",
        from_router,
        {"rag": "rag_lookup", "web": "web_search", "answer": "answer", "end": END},
    )
    # RAG either suffices (answer) or falls back to web search.
    graph.add_conditional_edges(
        "rag_lookup",
        after_rag,
        {"answer": "answer", "web": "web_search"},
    )
    graph.add_edge("web_search", "answer")
    graph.add_edge("answer", END)
    # MemorySaver checkpointer enables per-thread conversation memory.
    return graph.compile(checkpointer=MemorySaver())
# Module-level singleton: the compiled agent, built once at import time.
rag_agent = build_agent()