# utils/langgraph_app.py # Graphe LangGraph hybride RAG + SQL, sans dépendance Streamlit. # Importable à la fois par MistralChat.py et evaluate_ragas.py. import time import logfire from typing import Literal, TypedDict, List from pydantic import BaseModel, Field from langgraph.graph import StateGraph, START, END from langchain_core.messages import SystemMessage, HumanMessage from langchain_mistralai import ChatMistralAI from utils.config import ( MISTRAL_API_KEY, MODEL_NAME, TEMPERATURE, TOP_P, SEARCH_K, RAG_SYSTEM_PROMT, LLM_CALL_DELAY, ) from utils.sql_tool import build_sql_agent # ============================== # État du graphe # ============================== class AppState(TypedDict, total=False): user_question: str route: Literal["rag", "sql"] rag_contexts: List[str] # Chunks FAISS récupérés (vide si route SQL) rag_answer: str sql_answer: str final_answer: str # ============================== # Schéma de décision du routeur # ============================== class RouteDecision(BaseModel): """Valide que la destination ne peut être que 'rag' ou 'sql'.""" destination: Literal["rag", "sql"] = Field(...) # ============================== # Construction du graphe # ============================== def build_graph(vector_store_manager): """ Construit et retourne (graph, llm). Args: vector_store_manager : instance de VectorStoreManager (peut être None si le Vector Store n'est pas disponible — le nœud RAG retournera alors un message d'erreur). """ llm = ChatMistralAI( api_key=MISTRAL_API_KEY, model=MODEL_NAME, #top_p=TOP_P, temperature=TEMPERATURE, ) router = llm.with_structured_output(RouteDecision) # Connexion PostgreSQL établie ici (paresseuse), pas à l'import du module sql_agent = build_sql_agent() # --- Nœud routeur --- def router_node(state: AppState) -> dict: """Classe la question en 'rag' (texte narratif) ou 'sql' (données chiffrées).""" logfire.info(f"[Router] Question : '{state['user_question']}'") time.sleep(LLM_CALL_DELAY) # Rate limiting decision = router.invoke([ SystemMessage(content=""" Tu es un routeur pour un assistant NBA. Choisis : - 'rag' si la question porte sur du contenu textuel (articles, analyses narratives, contexte, règles, actualités, discussions Reddit). - 'sql' si la question demande un calcul, un comptage, un filtrage, un classement ou toute donnée chiffrée provenant des tables de la base NBA (joueurs, équipes, points, statistiques de saison). Réponds uniquement avec la destination. """), HumanMessage(content=state["user_question"]) ]) logfire.info(f"[Router] → {decision.destination}") return {"route": decision.destination} def route_after_router(state: AppState) -> str: """Renvoie la route choisie pour orienter le graphe vers le bon nœud.""" return state["route"] # --- Nœud RAG --- def rag_node(state: AppState) -> dict: """Recherche les chunks pertinents dans FAISS, puis génère une réponse contextualisée.""" question = state["user_question"] logfire.info(f"[RAG] Recherche pour : '{question}'") if vector_store_manager is None: return { "rag_answer": "Le service de recherche documentaire n'est pas disponible.", "rag_contexts": [""], } results = vector_store_manager.search(question, k=SEARCH_K) if results: context = "\n\n---\n\n".join( f"Source : {r['metadata'].get('source', 'Inconnue')} (Score : {r['score']:.1f}%)\n{r['text']}" for r in results ) contexts_list = [r["text"] for r in results] else: context = "Aucune information pertinente trouvée dans la base documentaire." contexts_list = [""] logfire.warn("[RAG] Aucun chunk pertinent trouvé.") time.sleep(LLM_CALL_DELAY) # Rate limiting response = llm.invoke([ SystemMessage(content=RAG_SYSTEM_PROMT.format(context_str=context, question=question)) ]) logfire.info("[RAG] Réponse générée.") return {"rag_answer": response.content, "rag_contexts": contexts_list} # --- Nœud SQL --- def sql_node(state: AppState) -> dict: """Délègue la question à l'agent SQL ReAct (sql_tool.py) et récupère sa réponse finale.""" question = state["user_question"] logfire.info(f"[SQL] Requête : '{question}'") time.sleep(LLM_CALL_DELAY) # Rate limiting result = sql_agent.invoke({ "messages": [{"role": "user", "content": question}] }) final_msg = result["messages"][-1].content logfire.info("[SQL] Réponse générée.") # Pas de contexte FAISS pour la route SQL return {"sql_answer": final_msg, "rag_contexts": [""]} # --- Nœud de synthèse finale --- def finalize_node(state: AppState) -> dict: """Sélectionne la réponse RAG ou SQL selon la route empruntée.""" if state.get("route") == "rag": return {"final_answer": state.get("rag_answer", "Aucune réponse RAG disponible.")} return {"final_answer": state.get("sql_answer", "Aucune réponse SQL disponible.")} # --- Compilation du graphe --- builder = StateGraph(AppState) builder.add_node("router_node", router_node) builder.add_node("rag_node", rag_node) builder.add_node("sql_node", sql_node) builder.add_node("finalize_node", finalize_node) builder.add_edge(START, "router_node") builder.add_conditional_edges( "router_node", route_after_router, {"rag": "rag_node", "sql": "sql_node"}, ) builder.add_edge("rag_node", "finalize_node") builder.add_edge("sql_node", "finalize_node") builder.add_edge("finalize_node", END) return builder.compile(), llm