File size: 3,290 Bytes
3107242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from langgraph.graph import StateGraph, END
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from src.mcp.client import summary_agent
from src.agents.rag_agent import rag_graph, invocation_state
from src.models.llm_wrapper import GeminiWrapper
import re
from src.utils.helpers import clean_agent_output, load_prompt_template
from src.configs.config import LOG_DIR
import logging
import os


# LOGGING SETUP
# All agent activity in this module is appended to a single shared log file.
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
# NOTE(review): basicConfig configures the ROOT logger, so any other module
# importing logging after this point inherits this file handler — confirm
# that is intended; a module-level logging.getLogger(__name__) would isolate it.
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)


# Module-level singletons shared by every node in the graph:
llm = GeminiWrapper()   # Gemini-backed LLM used for routing decisions
memory = MemorySaver()  # in-memory checkpointer for conversation state

# STATE
class MessagesState(TypedDict):
    """Graph state: the running conversation as an ordered message list.

    The last element is treated as the latest user message by the router
    and by the MCP node.
    """
    messages: List[BaseMessage]


# ROUTER
def route_query(state: dict) -> str:
    """Classify the latest user message as ``"rag"`` or ``"summary"``.

    Asks the LLM to pick one of the two labels; any generation error,
    empty response, or unexpected label falls back to ``"rag"``.

    Args:
        state: Graph state; ``state["messages"][-1].content`` is the query.

    Returns:
        ``"rag"`` or ``"summary"`` — the name of the next graph node.
    """
    query = state["messages"][-1].content.lower()

    # Inline classification prompt
    prompt = f"""
    You are a query router. Classify the following user query into one of two categories:
    - "rag": if the query requires retrieving factual or external information.
    - "summary": if the query asks to summarize document or youtube session from previous messages.

    Respond with only one word: either "rag" or "summary".

    User query: "{query}"
    """

    try:
        response = llm.generate(prompt)
    except Exception as e:
        logging.error(f"Error generating routing response: {e}")
        return "rag"  # fallback routing

    logging.info(f"Raw LLM response: {response!r}")
    # Guard the empty/whitespace-only case: "".split() is [] and indexing it
    # would raise IndexError instead of falling back to "rag".
    tokens = response.strip().lower().split()
    # Strip surrounding quotes/punctuation the model sometimes adds ('"summary".').
    decision = tokens[0].strip('"\'.,:;!') if tokens else ""

    logging.info(f"Routing decision for query '{query}': {decision}")
    if decision not in ("summary", "rag"):
        logging.warning(f"Unexpected routing result '{decision}', defaulting to 'rag'")
        return "rag"

    return decision

# MCP NODE
async def mcp_node(state: MessagesState) -> dict:
    """Run the MCP summary agent on the latest user message.

    Returns a state delta containing a single AIMessage: the cleaned agent
    output on success, or a French error notice if anything fails.
    """
    latest_input = state["messages"][-1].content
    logging.info(f"MCP agent input: {latest_input}")
    try:
        raw_output = await summary_agent.run(latest_input)
        logging.info(f"MCP raw output: {raw_output}")
        cleaned_output = clean_agent_output(raw_output)
        logging.info(f"MCP cleaned output: {cleaned_output}")
    except Exception as exc:
        logging.error(f"Error in MCP agent: {exc}")
        return {"messages": [AIMessage(content="Une erreur est survenue avec l’agent MCP.")]}
    return {"messages": [AIMessage(content=cleaned_output)]}
    
# GRAPH DEFINITION
builder = StateGraph(MessagesState)

# "router" is a deliberate pass-through node: the actual routing decision is
# made on the conditional edges below, but the graph needs a concrete entry node.
builder.add_node("router", lambda state: {"messages": state["messages"]})
builder.add_node("rag", rag_graph)     # retrieval-augmented generation subgraph
builder.add_node("summary", mcp_node)  # MCP summarization node

builder.set_entry_point("router")
# route_query already has the (state) -> str signature add_conditional_edges
# expects, so pass it directly instead of wrapping it in a redundant lambda.
builder.add_conditional_edges("router", route_query, {
    "rag": "rag",
    "summary": "summary"
})
builder.add_edge("rag", END)
builder.add_edge("summary", END)

# Compiled graph with in-memory checkpointing so conversations persist per thread.
multi_agent_graph = builder.compile(checkpointer=memory)