"""Multi-agent LangGraph entry point.

Routes the latest user message either to the RAG sub-graph or to the MCP
summary agent, based on a one-word classification returned by the LLM.
"""

import logging
import os
import re
from typing import TypedDict, List

from langchain_core.messages import BaseMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END

from src.agents.rag_agent import rag_graph, invocation_state
from src.configs.config import LOG_DIR
from src.mcp.client import summary_agent
from src.models.llm_wrapper import GeminiWrapper
from src.utils.helpers import clean_agent_output, load_prompt_template

# LOGGING SETUP
# Make sure the log directory exists; basicConfig opens the file eagerly and
# would otherwise fail when LOG_DIR has not been created yet.
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

llm = GeminiWrapper()
memory = MemorySaver()

# Routes the conditional edge knows about; anything else falls back to "rag".
_VALID_ROUTES = ("summary", "rag")

# First run of word characters, skipping leading whitespace/punctuation such
# as quotes or markdown emphasis that LLMs like to wrap answers in.
_FIRST_WORD_RE = re.compile(r"\W*(\w+)")


# STATE
class MessagesState(TypedDict):
    """Graph state: the list of messages exchanged so far."""

    # NOTE(review): no reducer is attached to this key, so a node's returned
    # "messages" presumably replaces the list rather than appending to it;
    # confirm this is intended.
    messages: List[BaseMessage]


def _parse_routing_decision(response) -> str:
    """Return the first word of the LLM reply, lowercased, or "" if none.

    Surrounding punctuation is ignored so replies such as '"summary".' or
    '**rag**' are still recognised. Non-string or empty replies yield "".
    """
    if not isinstance(response, str):
        return ""
    found = _FIRST_WORD_RE.match(response.lower())
    return found.group(1) if found else ""


# ROUTER
def route_query(state: dict) -> str:
    """Classify the latest message and return the name of the next node.

    Args:
        state: Graph state; ``state["messages"][-1].content`` is the query.

    Returns:
        ``"summary"`` or ``"rag"``. Falls back to ``"rag"`` when the LLM call
        raises or when its reply is empty or not one of the two routes.
    """
    query = state["messages"][-1].content.lower()

    # NOTE: invocation-type detection (setting invocation_state.invocation_type
    # to "api" when the query mentions "api", otherwise "chatbot") is currently
    # disabled.

    # Inline classification prompt. The text is kept verbatim; only the source
    # layout is split across adjacent literals.
    prompt = (
        ' You are a query router. Classify the following user query'
        ' into one of two categories:'
        ' - "rag": if the query requires retrieving factual or external'
        ' information.'
        ' - "summary": if the query asks to summarize document or youtube'
        ' session from previous messages.'
        ' Respond with only one word: either "rag" or "summary". \n'
        f'User query: "{query}" '
    )

    try:
        response = llm.generate(prompt)
    except Exception as e:
        logging.error(f"Error generating routing response: {e}")
        return "rag"  # fallback routing

    logging.info(f"Raw LLM response: {response!r}")

    # Parsing never raises: an empty/odd reply becomes "" and is rejected
    # by the membership check below instead of crashing the router.
    decision = _parse_routing_decision(response)
    logging.info(f"Routing decision for query '{query}': {decision}")

    if decision not in _VALID_ROUTES:
        logging.warning(f"Unexpected routing result '{decision}', defaulting to 'rag'")
        return "rag"

    return decision


# MCP NODE
async def mcp_node(state: MessagesState) -> dict:
    """Run the MCP summary agent on the latest message.

    Args:
        state: Graph state; the content of the last message is sent to
            ``summary_agent.run``.

    Returns:
        A state update whose ``"messages"`` holds a single ``AIMessage``: the
        agent output passed through ``clean_agent_output``, or a fixed error
        message if anything in the call raises.
    """
    user_message = state["messages"][-1].content
    logging.info(f"MCP agent input: {user_message}")

    try:
        raw_result = await summary_agent.run(user_message)
        logging.info(f"MCP raw output: {raw_result}")

        clean_result = clean_agent_output(raw_result)
        logging.info(f"MCP cleaned output: {clean_result}")

        return {"messages": [AIMessage(content=clean_result)]}
    except Exception as e:
        # Best-effort node: report the failure to the user instead of
        # aborting the graph run.
        logging.error(f"Error in MCP agent: {e}")
        return {"messages": [AIMessage(content="Une erreur est survenue avec l’agent MCP.")]}


# GRAPH DEFINITION
builder = StateGraph(MessagesState)

# Dummy pass-through node; the actual decision is made on its outgoing edge.
builder.add_node("router", lambda state: {"messages": state["messages"]})
builder.add_node("rag", rag_graph)      # RAG graph
builder.add_node("summary", mcp_node)   # MCP node

builder.set_entry_point("router")

builder.add_conditional_edges(
    "router",
    route_query,
    {
        "rag": "rag",
        "summary": "summary",
    },
)

builder.add_edge("rag", END)
builder.add_edge("summary", END)

multi_agent_graph = builder.compile(checkpointer=memory)