# NOTE: the lines previously at the top of this file ("Spaces:", "Sleeping",
# a file-size line, a commit hash, and a long run of line numbers) were
# web-page scrape residue, not Python source; replaced with this comment so
# the module parses.
# rag_agent_app/backend/agent.py
import os
from typing import List, Literal, TypedDict, Annotated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.graph.message import add_messages
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langchain_tavily import TavilySearch
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.runnables import RunnableConfig # <-- NEW LINE ADDED HERE
# Import API keys from config
from config import GROQ_API_KEY, TAVILY_API_KEY
from vectorstore import get_retriever
# --- Tools ---
# TavilySearch reads its API key from the environment, so export it before
# constructing the client.
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
# Shared search client: at most 3 general-topic results per query.
tavily = TavilySearch(max_results=3, topic="general")
@tool
def web_search_tool(query: str) -> str:
    """Up-to-date web info via Tavily"""
    # Any failure (network, parsing, unexpected payload) is reported inline
    # with a WEB_ERROR:: prefix so downstream nodes can detect it.
    try:
        raw = tavily.invoke({"query": query})
        if isinstance(raw, dict) and 'results' in raw:
            # One "Title/Content/URL" section per hit, separated by blank lines.
            sections = [
                f"Title: {hit.get('title', 'No title')}\n"
                f"Content: {hit.get('content', 'No content')}\n"
                f"URL: {hit.get('url', '')}"
                for hit in raw['results']
            ]
            return "\n\n".join(sections) if sections else "No results found"
        # Unexpected payload shape — return its string form verbatim.
        return str(raw)
    except Exception as exc:
        return f"WEB_ERROR::{exc}"
@tool
def rag_search_tool(query: str) -> str:
    """Top-K chunks from KB (empty string if none)"""
    # Errors come back as a RAG_ERROR::-prefixed string rather than raising,
    # so graph nodes can route around a broken retriever.
    try:
        retriever = get_retriever()
        # NOTE(review): passing k=5 at invoke time may be ignored by some
        # retrievers (search_kwargs is usually fixed at construction) — verify.
        docs = retriever.invoke(query, k=5)  # Increased from 3 to 5
        if not docs:
            return ""
        return "\n\n".join(doc.page_content for doc in docs)
    except Exception as exc:
        return f"RAG_ERROR::{exc}"
# --- Pydantic schemas for structured output ---
class RouteDecision(BaseModel):
    """Structured routing choice emitted by the router LLM."""
    # Which node handles the query next: KB lookup, web search,
    # direct answering, or ending the turn with a canned reply.
    route: Literal["rag", "web", "answer", "end"]
    reply: str | None = Field(None, description="Filled only when route == 'end'")
class RagJudge(BaseModel):
    """Structured verdict on whether retrieved RAG context is sufficient."""
    verdict: Literal["yes", "no"] = Field(..., description="Set to 'yes' if retrieved info is sufficient, 'no' otherwise.")
# --- LLM instances with structured output where needed ---
# ChatGroq reads the key from the environment, so export it first.
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
# Deterministic (temperature=0) models for routing and judging, each
# constrained to its pydantic schema above; a higher-temperature model
# for composing the final answer.
router_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0).with_structured_output(RouteDecision)
judge_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0).with_structured_output(RagJudge)
answer_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.7)
# --- Shared state type ---
class AgentState(TypedDict, total=False):
    """Graph state shared by all nodes (every key optional, total=False)."""
    # Conversation history; the add_messages reducer merges node outputs.
    messages: Annotated[List[BaseMessage], add_messages]
    # Next node selected by the router or the RAG sufficiency judge.
    route: Literal["rag", "web", "answer", "end"]
    # Retrieved knowledge-base chunks ("" when none or on error).
    rag: str
    # Web search snippets ("" on error; placeholder text when disabled).
    web: str
    # Per-request flag from config; False forbids the 'web' route.
    web_search_enabled: bool
# --- Node 1: router (decision) ---
def router_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Route the latest user message to 'rag', 'web', 'answer' or 'end'.

    Reads the per-request ``web_search_enabled`` flag from
    ``config["configurable"]`` (defaulting to True) and asks the router LLM
    for a structured RouteDecision. If the LLM picks 'web' while web search
    is disabled, the decision is overridden to 'rag' and the override is
    recorded in the returned state for tracing.
    """
    print("\n--- Entering router_node ---")
    # The most recent human message is the query to route.
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # Get web_search_enabled directly from the config.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    print(f"Router received web search info : {web_search_enabled}")
    system_prompt = (
        "You are an intelligent routing agent designed to direct user queries to the most appropriate tool."
        "Your primary goal is to provide accurate and relevant information by selecting the best source."
        "Prioritize using the **internal knowledge base (RAG)** for factual information that is likely "
        "to be contained within pre-uploaded documents or for common, well-established facts."
    )
    if web_search_enabled:
        system_prompt += (
            "You **CAN** use web search for queries that require very current, real-time, or broad general knowledge "
            "that is unlikely to be in a specific, static knowledge base (e.g., today's news, live data, very recent events)."
            "\n\nChoose one of the following routes:"
            "\n- 'rag': For queries about specific entities, historical facts, product details, procedures, or any information that would typically be found in a curated document collection (e.g., 'What is X?', 'How does Y work?', 'Explain Z policy')."
            "\n- 'web': For queries about current events, live data, very recent news, or broad general knowledge that requires up-to-date internet access (e.g., 'Who won the election yesterday?', 'What is the weather in London?', 'Latest news on technology')."
        )
    else:
        system_prompt += (
            "**Web search is currently DISABLED.** You **MUST NOT** choose the 'web' route."
            "If a query would normally require web search, you should attempt to answer it using RAG (if applicable) or directly from your general knowledge."
            "\n\nChoose one of the following routes:"
            "\n- 'rag': For queries about specific entities, historical facts, product details, procedures, or any information that would typically be found in a curated document collection, AND for queries that would normally go to web search but web search is disabled."
        )
    # Routes common to both branches plus few-shot examples.
    # BUGFIX: the disabled branch previously listed the 'answer' route here AND
    # in its own text above, duplicating the bullet in the prompt; it is now
    # appended exactly once for both branches.
    system_prompt += (
        "\n- 'answer': For very simple, direct questions you can answer without any external lookup (e.g., 'What is your name?')."
        "\n- 'end': For pure greetings or small-talk where no factual answer is expected (e.g., 'Hi', 'How are you?'). If choosing 'end', you MUST provide a 'reply'."
        "\n\nExample routing decisions:"
        "\n- User: 'What are the treatment of diabetes?' -> Route: 'rag' (Factual knowledge, likely in KB)."
        "\n- User: 'What is the capital of France?' -> Route: 'rag' (Common knowledge, can be in KB or answered directly if LLM knows)."
        "\n- User: 'Who won the NBA finals last night?' -> Route: 'web' (Current event, requires live data)."
        "\n- User: 'How do I submit an expense report?' -> Route: 'rag' (Internal procedure)."
        "\n- User: 'Tell me about quantum computing.' -> Route: 'rag' (Foundational knowledge can be in KB. If KB is sparse, judge will route to web if enabled)."
        "\n- User: 'Hello there!' -> Route: 'end', reply='Hello! How can I assist you today?'"
    )
    messages = [
        ("system", system_prompt),
        ("user", query)
    ]
    result: RouteDecision = router_llm.invoke(messages)
    initial_router_decision = result.route  # LLM's raw decision, kept for tracing
    router_override_reason = None
    # Enforce the flag in code even if the LLM ignored the prompt instructions.
    if not web_search_enabled and result.route == "web":
        result.route = "rag"
        router_override_reason = "Web search disabled by user; redirected to RAG."
        print(f"Router decision overridden: changed from 'web' to 'rag' because web search is disabled.")
    print(f"Router final decision: {result.route}, Reply (if 'end'): {result.reply}")
    out = {
        "messages": state["messages"],
        "route": result.route,
        "web_search_enabled": web_search_enabled  # Pass the flag along in the state
    }
    if router_override_reason:  # Add override info for tracing
        # NOTE(review): these keys are not declared on AgentState — confirm the
        # compiled graph's state schema tolerates/persists them.
        out["initial_router_decision"] = initial_router_decision
        out["router_override_reason"] = router_override_reason
    if result.route == "end":
        out["messages"] = [AIMessage(content=result.reply or "Hello!")]
    print("--- Exiting router_node ---")
    return out
# --- Node 2: RAG lookup ---
def rag_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Retrieve KB chunks for the latest user message and judge sufficiency.

    Routes to 'answer' when the judge deems the chunks sufficient; otherwise
    falls back to 'web' (only when web search is enabled) or to 'answer'.
    Retrieval errors and empty results skip the judge entirely.
    """
    print("\n--- Entering rag_node ---")
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # Get web_search_enabled directly from the config.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    # BUGFIX: this log line previously said "Router received ..." — a
    # copy-paste from router_node that mislabeled the operational logs.
    print(f"RAG node received web search info : {web_search_enabled}")
    print(f"RAG query: {query}")
    chunks = rag_search_tool.invoke(query)
    if chunks.startswith("RAG_ERROR::"):
        print(f"RAG Error: {chunks}. Checking web search enabled status.")
        # If RAG fails and web search is enabled, try web; otherwise answer.
        # CONSISTENCY FIX: carry web_search_enabled like the other returns do.
        next_route = "web" if web_search_enabled else "answer"
        return {**state, "rag": "", "route": next_route, "web_search_enabled": web_search_enabled}
    if not chunks:
        print("No RAG chunks retrieved. Skipping judge.")
        next_route = "web" if web_search_enabled else "answer"
        return {**state, "rag": "", "route": next_route, "web_search_enabled": web_search_enabled}
    # Ask the judge LLM whether the retrieved text actually answers the query.
    judge_messages = [
        ("system", (
            "You are a judge evaluating if the retrieved information is sufficient and relevant "
            "to fully and accurately answer the user's question. "
            "Consider if the retrieved text directly addresses the question's core and provides enough detail."
            "If the information is incomplete, vague, or doesn't directly answer the question, it is NOT sufficient."
            "\n\nRespond with 'yes' for sufficient, 'no' for insufficient."
        )),
        ("user", f"Question: {query}\n\nRetrieved info: {chunks}\n\nIs this sufficient to answer the question?")
    ]
    result: RagJudge = judge_llm.invoke(judge_messages)
    is_sufficient = result.verdict == "yes"
    print(f"RAG Judge verdict: {result.verdict} (is_sufficient={is_sufficient})")
    print("--- Exiting rag_node ---")
    # Only fall back to web search when the judge says "no" AND it is enabled.
    if is_sufficient:
        next_route = "answer"
    else:
        next_route = "web" if web_search_enabled else "answer"
        print(f"RAG not sufficient. Web search enabled: {web_search_enabled}. Next route: {next_route}")
    return {
        **state,
        "rag": chunks,
        "route": next_route,
        "web_search_enabled": web_search_enabled  # Pass the flag along
    }
# --- Node 3: web search ---
def web_node(state: AgentState, config: RunnableConfig) -> AgentState:
    """Run a Tavily web search for the latest user message, then go to 'answer'.

    Re-checks the per-request flag so an entry into this node cannot perform
    a search the caller disabled; search errors degrade to an empty snippet.
    """
    print("\n--- Entering web_node ---")
    query = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # Get web_search_enabled directly from the config.
    web_search_enabled = config.get("configurable", {}).get("web_search_enabled", True)
    # BUGFIX: this log line previously said "Router received ..." — a
    # copy-paste from router_node that mislabeled the operational logs.
    print(f"Web node received web search info : {web_search_enabled}")
    if not web_search_enabled:
        print("Web search node entered but web search is disabled. Skipping actual search.")
        # answer_node recognizes and filters out this placeholder text.
        return {**state, "web": "Web search was disabled by the user.", "route": "answer"}
    print(f"Web search query: {query}")
    snippets = web_search_tool.invoke(query)
    if snippets.startswith("WEB_ERROR::"):
        print(f"Web Error: {snippets}. Proceeding to answer with limited info.")
        return {**state, "web": "", "route": "answer"}
    print(f"Web snippets retrieved: {snippets[:200]}...")
    print("--- Exiting web_node ---")
    return {**state, "web": snippets, "route": "answer"}
# --- Node 4: final answer ---
def answer_node(state: AgentState) -> AgentState:
    """Compose the final reply from whatever RAG/web context was gathered."""
    print("\n--- Entering answer_node ---")
    user_q = next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")
    # Assemble the context sections that actually contain information.
    sections = []
    rag_text = state.get("rag")
    if rag_text:
        sections.append("Knowledge Base Information:\n" + rag_text)
    web_text = state.get("web")
    # The 'web' field may hold the "Web search was disabled..." placeholder;
    # only genuine search results belong in the context.
    if web_text and not web_text.startswith("Web search was disabled"):
        sections.append("Web Search Results:\n" + web_text)
    context = "\n\n".join(sections)
    if not context.strip():
        context = "No external context was available for this query. Try to answer based on general knowledge if possible."
    prompt = f"""Please answer the user's question using the provided context.
If the context is empty or irrelevant, try to answer based on your general knowledge.
Question: {user_q}
Context:
{context}
Provide a helpful, accurate, and concise response based on the available information."""
    print(f"Prompt sent to answer_llm: {prompt[:500]}...")
    # Full history plus the context-bearing prompt goes to the answer model.
    ans = answer_llm.invoke(state["messages"] + [HumanMessage(content=prompt)]).content
    print(f"Final answer generated: {ans[:200]}...")
    print("--- Exiting answer_node ---")
    return {**state, "messages": [AIMessage(content=ans)]}
# --- Routing helpers ---
def from_router(st: AgentState) -> Literal["rag", "web", "answer", "end"]:
    """Conditional-edge selector leaving the router node."""
    chosen = st["route"]
    return chosen

def after_rag(st: AgentState) -> Literal["answer", "web"]:
    """Conditional-edge selector leaving the RAG lookup node."""
    chosen = st["route"]
    return chosen

def after_web(_) -> Literal["answer"]:
    """Web search always flows into the answer node."""
    return "answer"
# --- Build graph ---
def build_agent():
    """Builds and compiles the LangGraph agent."""
    graph = StateGraph(AgentState)
    # Register all four worker nodes.
    for node_name, node_fn in (
        ("router", router_node),
        ("rag_lookup", rag_node),
        ("web_search", web_node),
        ("answer", answer_node),
    ):
        graph.add_node(node_name, node_fn)
    graph.set_entry_point("router")
    # The router can fan out to any node or end the turn immediately.
    graph.add_conditional_edges("router", from_router, {
        "rag": "rag_lookup",
        "web": "web_search",
        "answer": "answer",
        "end": END,
    })
    # RAG lookup either answers directly or falls back to web search.
    graph.add_conditional_edges("rag_lookup", after_rag, {
        "answer": "answer",
        "web": "web_search",
    })
    graph.add_edge("web_search", "answer")
    graph.add_edge("answer", END)
    # In-memory checkpointer keeps per-thread conversation state.
    return graph.compile(checkpointer=MemorySaver())
# Module-level singleton agent, built once at import time.
# BUGFIX: removed a stray trailing "|" (scrape residue) that made this line
# a syntax error.
rag_agent = build_agent()