Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, END | |
| from typing import TypedDict, List | |
| from langchain_core.messages import BaseMessage, AIMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from src.mcp.client import summary_agent | |
| from src.agents.rag_agent import rag_graph, invocation_state | |
| from src.models.llm_wrapper import GeminiWrapper | |
| import re | |
| from src.utils.helpers import clean_agent_output, load_prompt_template | |
| from src.configs.config import LOG_DIR | |
| import logging | |
| import os | |
# LOGGING SETUP
# Everything in this module logs through the root logger into one file
# under LOG_DIR.
# The file handler created by basicConfig does not create missing parent
# directories, so make sure the target directory exists before configuring.
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    # User queries and agent output are logged verbatim; pin the encoding so
    # non-ASCII text is written the same way on every platform (Python 3.9+).
    encoding="utf-8",
)

# Module-level instances: `llm` is used by the router to classify queries,
# `memory` is the checkpointer passed to the compiled graph.
llm = GeminiWrapper()
memory = MemorySaver()
# STATE
class MessagesState(TypedDict):
    """State schema of the multi-agent graph: a single list of messages."""

    # The nodes in this file read the last entry as the current user input.
    # NOTE(review): no reducer is annotated on this key; confirm that a node
    # returning {"messages": [...]} is meant to replace the list rather than
    # extend it.
    messages: List[BaseMessage]
# ROUTER
def route_query(state: dict) -> str:
    """Classify the latest user message and return the name of the next node.

    The content of the last message in ``state["messages"]`` is lower-cased,
    embedded in a classification prompt and sent to the module-level ``llm``.
    The reply is scanned for the first stand-alone occurrence of one of the
    two labels, so decorated answers such as ``"Summary."`` or ``**rag**``
    are still understood.

    Args:
        state: Graph state; ``state["messages"][-1]`` must expose a string
            ``content`` attribute.

    Returns:
        ``"rag"`` or ``"summary"``. ``"rag"`` is the fallback whenever the
        LLM call raises, or its reply is empty or contains neither label.

    Raises:
        IndexError: If ``state["messages"]`` is empty.
    """
    query = state["messages"][-1].content.lower()

    # Inline classification prompt (leading and trailing newline included).
    prompt = (
        "\n"
        "You are a query router. Classify the following user query into one of two categories:\n"
        '- "rag": if the query requires retrieving factual or external information.\n'
        '- "summary": if the query asks to summarize document or youtube session from previous messages.\n'
        'Respond with only one word: either "rag" or "summary".\n'
        f'User query: "{query}"\n'
    )

    try:
        response = llm.generate(prompt)
    except Exception as e:
        logging.error(f"Error generating routing response: {e}")
        return "rag"  # fallback routing

    logging.info(f"Raw LLM response: {response!r}")

    # Tolerate a missing reply instead of failing on .strip()/[0], and look
    # for the label as a whole word so quotes, punctuation or markdown around
    # it do not push a valid answer into the fallback.
    reply = "" if response is None else str(response)
    label = re.search(r"\b(rag|summary)\b", reply.lower())
    if label is None:
        logging.warning(f"Unexpected routing result '{reply.strip()}', defaulting to 'rag'")
        return "rag"

    decision = label.group(1)
    logging.info(f"Routing decision for query '{query}': {decision}")
    return decision
# MCP NODE
async def mcp_node(state: MessagesState) -> dict:
    """Run the summary agent on the latest message and wrap its answer.

    Args:
        state: Graph state; the content of the last message is the agent input.

    Returns:
        A state update ``{"messages": [AIMessage]}`` carrying either the
        cleaned agent output or, if anything in the agent call or the
        clean-up raised, a fixed error message.
    """
    agent_input = state["messages"][-1].content
    logging.info(f"MCP agent input: {agent_input}")

    try:
        agent_output = await summary_agent.run(agent_input)
        logging.info(f"MCP raw output: {agent_output}")

        cleaned_output = clean_agent_output(agent_output)
        logging.info(f"MCP cleaned output: {cleaned_output}")

        # Built inside the try so a failure here is reported like any other.
        reply = AIMessage(content=cleaned_output)
    except Exception as e:
        logging.error(f"Error in MCP agent: {e}")
        reply = AIMessage(content="Une erreur est survenue avec l’agent MCP.")

    return {"messages": [reply]}
# GRAPH DEFINITION
# Topology: router -> (rag | summary) -> END
builder = StateGraph(MessagesState)

# The router node hands the messages through unchanged; the actual decision
# is taken on its outgoing conditional edge.
builder.add_node("router", lambda state: {"messages": state["messages"]})
builder.add_node("rag", rag_graph)  # RAG graph
builder.add_node("summary", mcp_node)  # MCP node
builder.set_entry_point("router")

# Value returned by route_query -> name of the node to run next.
_ROUTES = {
    "rag": "rag",
    "summary": "summary",
}
builder.add_conditional_edges("router", lambda state: route_query(state), _ROUTES)

# Both agents finish the run.
for _agent_node in _ROUTES.values():
    builder.add_edge(_agent_node, END)

multi_agent_graph = builder.compile(checkpointer=memory)