Multi-Rag / src /nodes /main_nodes.py
VashuTheGreat2's picture
Upload folder using huggingface_hub
9c90775 verified
Raw
History Blame Contribute Delete
7.9 kB
import json
import asyncio
import logging
from langchain_core.messages import HumanMessage, SystemMessage
from src.llm.llm_loader import llm
from src.tools import WebSearch
from src.utils.asyncHandler import asyncHandler
from src.entity.config_entity import RetreiverConfig
from src.retrievers.create_retreivers import Retreiver
from src.states.Main_State import (
State,
Orchastrator_output,
Query_generation_output,
Relevance_output,
WebSearchOutput
)
from src.prompts.prompt_templates import (
ORCHESTRATOR_PROMPT,
QUERY_GENERATION_PROMPT,
RELEVANCE_CHECKER_PROMPT,
WEB_SEARCH_PROMPT,
CHAT_PROMPT
)
# Module-level WebSearch instance, created once at import time and reused by
# web_search_node for every search it issues.
web_search_tool = WebSearch()
@asyncHandler
async def orchastrator_node(state: State) -> dict:
    """Ask the LLM whether the current conversation requires a database search.

    The model receives ``ORCHESTRATOR_PROMPT`` as a system message followed by
    the messages stored in ``state`` and is constrained to the
    ``Orchastrator_output`` schema.

    Args:
        state: Pipeline state mapping. Its ``"messages"`` entry (empty list
            when absent) is forwarded to the model.

    Returns:
        A dict with the single key ``"require_db_search"`` holding the value
        of the same-named field on the structured model output.
    """
    logging.info("Orchestrator node started")

    structured_llm = llm.with_structured_output(Orchastrator_output)
    messages = [
        SystemMessage(content=ORCHESTRATOR_PROMPT),
        *state.get("messages", []),
    ]

    # Await the async variant: a synchronous invoke() inside a coroutine would
    # block the event loop for the whole LLM round-trip.
    result = await structured_llm.ainvoke(messages)

    logging.info(
        "Orchestrator routing decision: require_db_search=%s",
        result.require_db_search,
    )
    return {"require_db_search": result.require_db_search}
@asyncHandler
async def query_generation_node(state: State) -> dict:
    """Generate retrieval queries from the conversation.

    The model receives ``QUERY_GENERATION_PROMPT`` as a system message followed
    by the messages stored in ``state`` and is constrained to the
    ``Query_generation_output`` schema.

    Args:
        state: Pipeline state mapping. Its ``"messages"`` entry (empty list
            when absent) is forwarded to the model.

    Returns:
        A dict with the single key ``"queries"`` holding the ``queries`` field
        of the structured model output.
    """
    logging.info("Query generation node started")

    structured_llm = llm.with_structured_output(Query_generation_output)
    messages = [
        SystemMessage(content=QUERY_GENERATION_PROMPT),
        *state.get("messages", []),
    ]

    # Await the async variant so the LLM call does not block the event loop.
    result = await structured_llm.ainvoke(messages)

    logging.info("Generated %s queries", len(result.queries))
    return {"queries": result.queries}
@asyncHandler
async def retreiver_node(state: State) -> dict:
    """Run every generated query against the merged vector stores.

    Vector-store paths come from ``state["vector_store_file_paths"]``; when
    that is empty, the single ``state["vector_store_file_path"]`` is used
    instead (if set). All queries are executed concurrently and the returned
    documents are de-duplicated on their ``page_content``.

    Args:
        state: Pipeline state mapping; reads ``"vector_store_file_paths"``,
            ``"vector_store_file_path"`` and ``"queries"``.

    Returns:
        A dict with the single key ``"retreived_results"``: the unique
        documents in first-seen order, or an empty list when no retriever
        could be built or no queries are present.
    """
    logging.info("Retriever node started")

    retriever_obj = Retreiver(retreiver_config=RetreiverConfig())

    store_paths = state.get("vector_store_file_paths", [])
    if not store_paths and state.get("vector_store_file_path"):
        # Fall back to the single-path key when the list is missing or empty.
        store_paths = [state["vector_store_file_path"]]

    merged_retriever = await retriever_obj.merge_vector_stores(
        vector_store_paths=store_paths
    )
    if not merged_retriever:
        logging.warning("No retriever chain available")
        return {"retreived_results": []}

    search_queries = state.get("queries", [])
    if not search_queries:
        logging.warning("No queries available for retrieval")
        return {"retreived_results": []}

    # Fan out one retrieval per query and wait for all of them together.
    per_query_docs = await asyncio.gather(
        *(merged_retriever.ainvoke(single_query) for single_query in search_queries)
    )

    unique_docs = []
    seen_texts = set()
    for docs in per_query_docs:
        for document in docs:
            text = document.page_content
            if text not in seen_texts:
                seen_texts.add(text)
                metadata = document.metadata
                if "relevance_score" in metadata:
                    # Coerce the score to a built-in float in place.
                    metadata["relevance_score"] = float(metadata["relevance_score"])
                unique_docs.append(document)

    logging.info(
        f"Retriever returned {len(unique_docs)} unique documents"
    )
    return {"retreived_results": unique_docs}
@asyncHandler
async def is_retreived_data_enough(state: State) -> dict:
    """Ask the LLM whether the retrieved documents are relevant to the query.

    The content of the last message in ``state["messages"]`` is treated as the
    user query. It is formatted into ``RELEVANCE_CHECKER_PROMPT`` together
    with the list of ``page_content`` values of the retrieved documents, and
    the model is constrained to the ``Relevance_output`` schema.

    Args:
        state: Pipeline state mapping; reads ``"retreived_results"`` (empty
            list when absent) and ``"messages"``.

    Returns:
        A dict with the single key ``"relevance"`` holding the ``relevance``
        field of the structured model output.

    Raises:
        IndexError: If ``state`` holds no messages, since the last message is
            indexed unconditionally.
    """
    logging.info("Relevance checker node started")

    retrieved_docs = state.get("retreived_results", [])
    docs_content = [doc.page_content for doc in retrieved_docs]
    user_query = state.get("messages", [])[-1].content

    # NOTE(review): docs_content is passed to .format() as a list, not as a
    # joined string — confirm the prompt template expects that rendering.
    prompt = RELEVANCE_CHECKER_PROMPT.format(
        user_query=user_query,
        retreived_docs_content=docs_content,
    )

    structured_llm = llm.with_structured_output(Relevance_output)
    # Await the async variant so the LLM call does not block the event loop.
    result = await structured_llm.ainvoke([SystemMessage(content=prompt)])

    logging.info("Relevance decision: %s", result.relevance)
    return {"relevance": result.relevance}
@asyncHandler
async def web_search_node(state: State) -> dict:
    """Generate web-search queries with the LLM and run them concurrently.

    The content of the last message in ``state["messages"]`` is formatted into
    ``WEB_SEARCH_PROMPT``; the model, constrained to the ``WebSearchOutput``
    schema, returns the queries to run. Each query is sent through
    ``web_search_tool.search`` and the per-query results are flattened into
    one list.

    Args:
        state: Pipeline state mapping; reads ``"messages"``.

    Returns:
        A dict with the single key ``"web_search_results"`` holding the
        flattened search results.

    Raises:
        IndexError: If ``state`` holds no messages, since the last message is
            indexed unconditionally.
    """
    logging.info("Web search node started")

    query = state.get("messages", [])[-1].content

    structured_llm = llm.with_structured_output(WebSearchOutput)
    # Await the async variant so query generation does not block the event
    # loop, consistent with the awaited searches below.
    generated_queries = await structured_llm.ainvoke(
        [SystemMessage(content=WEB_SEARCH_PROMPT.format(query=query))]
    )

    search_tasks = [
        web_search_tool.search.ainvoke(q)
        for q in generated_queries.queries
    ]
    raw_results = await asyncio.gather(*search_tasks)

    # A search may return either a list of hits or a single value; wrap the
    # latter so everything flattens into one list.
    results = [
        item
        for sublist in raw_results
        for item in (sublist if isinstance(sublist, list) else [sublist])
    ]

    logging.info("Web search returned %s results", len(results))
    return {"web_search_results": results}
@asyncHandler
async def document_refiner(state: State) -> dict:
    """Expose the retrieved documents under the ``"refined_results"`` key.

    No filtering or transformation is applied: the value of
    ``state["retreived_results"]`` (empty list when absent) is passed through
    unchanged.

    Args:
        state: Pipeline state mapping; reads ``"retreived_results"``.

    Returns:
        A dict with the single key ``"refined_results"``.
    """
    logging.info("Document refiner node started")
    passthrough_docs = state.get("retreived_results", [])
    return {"refined_results": passthrough_docs}
@asyncHandler
async def get_chat_node_content(state: State) -> dict:
    """Build a multimodal prompt from the documents and query the LLM with it.

    The content of the last message in ``state["messages"]`` is the question.
    Documents come from ``state["refined_results"]``, falling back to
    ``state["retreived_results"]``. For each document:

    * without ``"original_content"`` in its metadata, its ``page_content`` is
      appended as-is;
    * otherwise ``metadata["original_content"]`` is decoded as JSON and its
      ``"raw_text"``, ``"tables_html"`` and ``"images_base64"`` entries are
      added as text, table and image parts respectively.

    Any ``state["web_search_results"]`` are appended after the documents. The
    LLM receives one ``HumanMessage`` whose content is the text part followed
    by all image parts.

    Args:
        state: Pipeline state mapping; reads ``"messages"``,
            ``"refined_results"``, ``"retreived_results"`` and
            ``"web_search_results"``.

    Returns:
        A dict with the single key ``"docs_feed_to_llm"`` holding the content
        of the LLM response.

    Raises:
        IndexError: If ``state`` holds no messages.
        json.JSONDecodeError: If a document's ``"original_content"`` is not
            valid JSON.
    """
    logging.info("Preparing multimodal context")

    query = state.get("messages", [])[-1].content
    chunks = state.get(
        "refined_results",
        state.get("retreived_results", []),
    )

    # Collect fragments and join once instead of repeated string +=.
    text_parts = [
        "\nBased on the following documents answer the question.\n"
        f"Question:\n{query}\nCONTENT:\n"
    ]
    image_parts = []

    # Single pass: each chunk's JSON payload is decoded once and feeds both
    # the text prompt and the image attachments.
    for index, chunk in enumerate(chunks):
        text_parts.append(f"\n--- Document {index + 1} ---\n")

        if "original_content" not in chunk.metadata:
            text_parts.append(chunk.page_content)
            continue

        original_data = json.loads(chunk.metadata["original_content"])

        raw_text = original_data.get("raw_text", "")
        if raw_text:
            text_parts.append(f"\nTEXT:\n{raw_text}\n")

        tables = original_data.get("tables_html", [])
        if tables:
            text_parts.append("\nTABLES:\n")
            text_parts.extend(f"{table}\n" for table in tables)

        # `or []` tolerates an explicit null stored under the key.
        for image in original_data.get("images_base64", []) or []:
            image_parts.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image}"
                    },
                }
            )

    web_results = state.get("web_search_results", [])
    if web_results:
        text_parts.append("\nWEB SEARCH RESULTS:\n")
        text_parts.extend(f"{result}\n" for result in web_results)

    # The text part goes first, followed by every image in document order.
    message_content = [
        {"type": "text", "text": "".join(text_parts)},
        *image_parts,
    ]

    # Await the async variant so the LLM call does not block the event loop.
    response = await llm.ainvoke([HumanMessage(content=message_content)])

    logging.info("Document context prepared successfully")
    return {"docs_feed_to_llm": response.content}
@asyncHandler
async def chat_node(state: State) -> dict:
    """Produce the final assistant reply for the conversation.

    The prompt is ``CHAT_PROMPT`` as a system message, followed by the
    messages stored in ``state``. When ``state["docs_feed_to_llm"]`` is
    truthy, it is appended as a trailing ``HumanMessage`` prefixed with
    ``"Context:"``.

    Args:
        state: Pipeline state mapping; reads ``"messages"`` (empty list when
            absent) and ``"docs_feed_to_llm"``.

    Returns:
        A dict with ``"messages"`` (a one-element list containing the LLM
        response message) and ``"ai_response"`` (that response's content).
    """
    logging.info("Chat node started")

    prompt = [
        SystemMessage(content=CHAT_PROMPT),
        *state.get("messages", []),
    ]

    docs_context = state.get("docs_feed_to_llm")
    if docs_context:
        prompt.append(HumanMessage(content=f"Context:\n{docs_context}"))

    # Await the async variant so the LLM call does not block the event loop.
    response = await llm.ainvoke(prompt)

    logging.info("Final response generated")
    return {
        "messages": [response],
        "ai_response": response.content,
    }