File size: 7,195 Bytes
9cc7f8d bb05158 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 | from typing import TypedDict , Annotated , List
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage , HumanMessage
from langchain_openai import ChatOpenAI
import os
from src.retrieval import Retriever
import os
from tavily import TavilyClient
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START ,END
from langgraph.checkpoint.postgres import PostgresSaver
from psycopg_pool import ConnectionPool
load_dotenv()
class State(TypedDict):
    """Shared state dictionary passed between the nodes of the graph."""

    # Conversation messages; `add_messages` is attached as this key's reducer.
    messages: Annotated[list, add_messages]
    # Chunks fed to the answer prompt; answer_node reads the keys
    # "source", "pages", "section" and "text" from each entry.
    context: List[dict]
    # Search string written by rewrite_node and read by the retrieval and
    # web-search nodes.
    rewritten_query: str
    # Passed to the retriever together with the rewritten query.
    user_id: str
    # Written by answer_node, read by routing.
    web_search_needed: bool
    # Incremented by web_search_node on every run; answer_node switches
    # prompts once it is >= 1.
    retry: int
# Chat model client pointed at the OpenRouter base URL; the API key is read
# from the OPENROUTER_API_KEY environment variable.
llm = ChatOpenAI(
    temperature=0,
    model="openai/gpt-4o-mini",
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
)

# Project retriever (src.retrieval), used by retrieve_node.
retriever = Retriever()

# Tavily client used by web_search_node; key read from TAVILY_API_KEY.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
# System instructions for the query-rewriting call made by rewrite_node.
_REWRITE_SYSTEM_PROMPT = """You are an expert search query generator for a vector database.
Your ONLY job is to convert the user's latest input into a single, highly optimized search string.
You will receive a sequence of the user's previous questions, followed by their newest input.
CRITICAL RULES:
1. TRACK THE TRAIN OF THOUGHT: If the latest input uses pronouns (it, they, this) or is a fragment (e.g., "What about the budget?"), identify the core noun from the previous questions and substitute it.
2. NO CONVERSATIONAL FILLER: Do not answer the question. Do not explain your reasoning.
3. FORMAT: Output only the raw search keywords. No commas, no bullet points.
Example Input:
Chat History:
What is the main objective of Project Chronos?
Who is the lead engineer?
Latest User Input: What is his total budget for Q4?
Example Output: Project Chronos lead engineer budget
"""


def rewrite_node(state: State):
    """Rewrite the newest human message into a standalone search query.

    Only ``HumanMessage`` entries of ``state['messages']`` are considered:
    the last one is the question to rewrite and the earlier ones are sent
    along as history so the model can resolve pronouns and fragments.

    Returns:
        ``{'rewritten_query': <model reply text>}``.

    Raises:
        IndexError: if ``state['messages']`` contains no ``HumanMessage``.
    """
    human_turns = [
        entry for entry in state['messages'] if isinstance(entry, HumanMessage)
    ]
    newest_question = human_turns[-1].content
    earlier_questions = "\n".join(turn.content for turn in human_turns[:-1])

    request = [
        SystemMessage(content=_REWRITE_SYSTEM_PROMPT),
        HumanMessage(
            content=(
                f"Chat History: {earlier_questions}\n\n"
                f"Latest User Input: {newest_question}\n\n"
                "Generate the concise search query now:"
            )
        ),
    ]
    reply = llm.invoke(request)

    # Debug trace of the generated query, framed by two rules of '='.
    rule = "=" * 60
    print(
        "\n" + rule,
        f"\n ReQuery : \n{reply.content} \n",
        rule + "\n",
        sep="\n",
        flush=True,
    )
    return {'rewritten_query': reply.content}
def retrieve_node(state: State):
    """Fetch context for the rewritten query via the module-level retriever.

    Reads ``user_id`` and ``rewritten_query`` from the state and returns
    ``{'context': ...}`` holding whatever ``retriever.retrieve`` returned.
    """
    owner = state['user_id']
    search_string = state['rewritten_query']
    return {'context': retriever.retrieve(search_string, owner)}
# System prompt used before any web search has run (retry < 1). The model is
# told to reply with the FALLBACK_TO_WEB_SEARCH sentinel when the context is
# insufficient.
_ANSWER_PROMPT_DOCS_ONLY = """
You are an advanced enterprise RAG assistant. Your job is to answer the user's latest question
by strictly analyzing the conversation history and the provided document chunks below.
CRITICAL RULES:
1. Base your answer ONLY on the text snippets provided in the Context section below. Do not assume or extrapolate.
2. If the context does not contain the answer, or if the context is irrelevant to the question,
you must reply with exactly this phrase and absolutely nothing else: FALLBACK_TO_WEB_SEARCH
3. You MUST inline cite your sources whenever you use information from a chunk.
Format your citations cleanly at the end of sentences like this: [Source: file.pdf, Page: X].
CONTEXT DATA:
{context_text}
"""

# System prompt used once web results have been merged into the context
# (retry >= 1); it forbids asking for another web search.
_ANSWER_PROMPT_WITH_WEB = """
You are an advanced enterprise RAG assistant. Your job is to answer the user's latest question
by strictly analyzing the conversation history and the provided document chunks below.
These chunks now include both internal documents and live web search results.
CRITICAL RULES:
1. Base your answer ONLY on the text snippets provided in the Context section below. Do not assume or extrapolate.
2. DO NOT ask for another web search. If the answer is still not found in the provided context, you must politely inform the user that the information is unavailable.
3. You MUST inline cite your sources whenever you use information from a chunk.
Format your citations cleanly at the end of sentences like this: [Source: file.pdf, Page: X] or [Source: website_url].
CONTEXT DATA:
{context_text}
"""


def _format_context(context):
    """Render the context chunks as the text block embedded in the prompt."""
    if not context:
        return "No relevant context found in the database for this specific query."
    return "".join(
        f"\n--- Document Chunk {position} ---\n"
        f"Source: {chunk.get('source', 'Unknown')}\n"
        f"Pages: {chunk.get('pages', 'N/A')}\n"
        f"Section: {chunk.get('section', 'N/A')}\n"
        f"Content: {chunk.get('text', '')}\n"
        for position, chunk in enumerate(context, start=1)
    )


def answer_node(state: State):
    """Answer the conversation from the current context, or request a web search.

    The chunks in ``state['context']`` are formatted into a system prompt,
    which is sent to the LLM in front of ``state['messages']``. The prompt
    variant depends on ``state['retry']`` (treated as 0 when absent).

    Returns:
        ``{"web_search_needed": True}`` when the model replies with exactly
        ``FALLBACK_TO_WEB_SEARCH`` (ignoring surrounding whitespace);
        otherwise ``{"messages": [response], "web_search_needed": False}``.
    """
    messages = state['messages']
    context_text = _format_context(state['context'])
    retry = state.get('retry', 0)

    # Debug trace of the context handed to the model.
    rule = "=" * 60
    print(
        "\n" + rule,
        f"\n\nCONTEXT TEXT :\n\n{context_text}",
        rule + "\n",
        sep="\n",
        flush=True,
    )

    template = _ANSWER_PROMPT_DOCS_ONLY if retry < 1 else _ANSWER_PROMPT_WITH_WEB
    # Both variants must be wrapped in SystemMessage: a bare string placed in
    # the message list is not sent with the system role.
    system_prompt = SystemMessage(content=template.format(context_text=context_text))

    response = llm.invoke([system_prompt] + messages)

    # NOTE(review): this check is not gated on `retry`; if the model still
    # returned the sentinel after a web search, routing would send the graph
    # back to web_search_node again. Confirm whether a retry cap is wanted.
    if response.content.strip() == "FALLBACK_TO_WEB_SEARCH":
        return {"web_search_needed": True}
    return {"messages": [response], "web_search_needed": False}
def routing(state: State):
    """Choose the branch taken after answer_node.

    Returns the key ``"web_search_node"`` when ``state["web_search_needed"]``
    is truthy and the key ``"END"`` otherwise.
    """
    return "web_search_node" if state["web_search_needed"] else "END"
def web_search_node(state: State):
    """Append Tavily web results to the context and bump the retry counter.

    Searches for ``state['rewritten_query']`` (at most 3 results), converts
    every hit into a chunk dict with the keys ``text``, ``source``,
    ``pages`` and ``section``, and returns the existing context followed by
    those chunks, together with ``retry`` incremented by one (starting from
    0 when the key is absent).
    """
    search_string = state['rewritten_query']
    existing_chunks = state['context']
    attempts = state.get('retry', 0)

    hits = tavily_client.search(query=search_string, max_results=3)['results']
    web_chunks = [
        {
            "text": hit.get("content", ""),
            "source": hit.get("url", "Live Web Search"),
            "pages": "N/A",
            "section": "Internet Result",
        }
        for hit in hits
    ]
    return {'context': existing_chunks + web_chunks, 'retry': attempts + 1}
# Graph wiring:
#   START -> rewrite_node -> retrieve_node -> answer_node
#   answer_node -> web_search_node -> answer_node   (when routing asks for it)
#   answer_node -> END                              (otherwise)
workflow = StateGraph(State)

# Every node is registered under the name of the function implementing it.
for _node in (rewrite_node, retrieve_node, answer_node, web_search_node):
    workflow.add_node(_node.__name__, _node)

# Linear part of the pipeline.
for _src, _dst in (
    (START, "rewrite_node"),
    ("rewrite_node", "retrieve_node"),
    ("retrieve_node", "answer_node"),
):
    workflow.add_edge(_src, _dst)

# routing() returns one of the keys below, which is mapped to its target.
workflow.add_conditional_edges(
    "answer_node",
    routing,
    {
        "web_search_node": "web_search_node",
        "END": END,
    },
)

# After a web search the enlarged context goes back to answer_node.
workflow.add_edge("web_search_node", "answer_node")