# RAG_APP/src/agents/main_agent.py
# (source: upload by sxid003, "Upload 83 files", commit 3107242)
from langgraph.graph import StateGraph, END
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from src.mcp.client import summary_agent
from src.agents.rag_agent import rag_graph, invocation_state
from src.models.llm_wrapper import GeminiWrapper
import re
from src.utils.helpers import clean_agent_output, load_prompt_template
from src.configs.config import LOG_DIR
import logging
import os
# LOGGING SETUP
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
# Create the log directory up front: basicConfig(filename=...) raises
# FileNotFoundError if LOG_DIR does not exist (e.g. on a fresh checkout).
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Module-level singletons shared by the router and the graph:
# one LLM wrapper for routing decisions, one in-memory checkpointer
# for LangGraph conversation state.
llm = GeminiWrapper()
memory = MemorySaver()
# STATE
class MessagesState(TypedDict):
    """Shared graph state: the running list of chat messages."""
    # Full conversation history; the last entry is the current user turn.
    messages: List[BaseMessage]
# ROUTER
# ROUTER
def route_query(state: MessagesState) -> str:
    """Classify the latest user message as 'rag' or 'summary'.

    Asks the LLM to label the query with a single word. Falls back to
    "rag" on a generation error, an empty response, or an unexpected
    label, so the graph always receives a valid branch name.
    """
    query = state["messages"][-1].content.lower()

    # Inline classification prompt
    prompt = f"""
You are a query router. Classify the following user query into one of two categories:
- "rag": if the query requires retrieving factual or external information.
- "summary": if the query asks to summarize document or youtube session from previous messages.
Respond with only one word: either "rag" or "summary".
User query: "{query}"
"""
    try:
        response = llm.generate(prompt)
    except Exception as e:
        logging.error(f"Error generating routing response: {e}")
        return "rag"  # fallback routing

    logging.info(f"Raw LLM response: {response!r}")
    # Guard against an empty/whitespace-only response before indexing:
    # the original `.split()[0]` raised IndexError in that case.
    tokens = response.strip().lower().split()
    if not tokens:
        logging.warning("Empty routing response, defaulting to 'rag'")
        return "rag"
    decision = tokens[0]
    logging.info(f"Routing decision for query '{query}': {decision}")
    if decision not in ("summary", "rag"):
        logging.warning(f"Unexpected routing result '{decision}', defaulting to 'rag'")
        return "rag"
    return decision
# MCP NODE
# MCP NODE
async def mcp_node(state: MessagesState) -> dict:
    """Run the MCP summary agent on the latest user message.

    Returns a state update containing a single AIMessage: the cleaned
    agent output on success, or a French error message if the agent
    fails for any reason.
    """
    user_message = state["messages"][-1].content
    logging.info(f"MCP agent input: {user_message}")
    try:
        raw_result = await summary_agent.run(user_message)
        logging.info(f"MCP raw output: {raw_result}")
        clean_result = clean_agent_output(raw_result)
        logging.info(f"MCP cleaned output: {clean_result}")
        return {"messages": [AIMessage(content=clean_result)]}
    except Exception as e:
        # logging.exception records the full traceback; logging.error with
        # just str(e) discarded the stack, making failures hard to debug.
        logging.exception(f"Error in MCP agent: {e}")
        return {"messages": [AIMessage(content="Une erreur est survenue avec l’agent MCP.")]}
# GRAPH DEFINITION
builder = StateGraph(MessagesState)
# Pass-through entry node: the actual routing happens on the outgoing
# conditional edge below, so this node just forwards the state unchanged.
builder.add_node("router", lambda state: {"messages": state["messages"]})
builder.add_node("rag", rag_graph)      # RAG sub-graph
builder.add_node("summary", mcp_node)   # MCP summary node
builder.set_entry_point("router")
# route_query already has the (state) -> str signature expected here;
# the previous `lambda state: route_query(state)` wrapper was redundant.
builder.add_conditional_edges("router", route_query, {
    "rag": "rag",
    "summary": "summary"
})
builder.add_edge("rag", END)
builder.add_edge("summary", END)
multi_agent_graph = builder.compile(checkpointer=memory)