Spaces:
Running
Running
| # utils/langgraph_app.py | |
| # Graphe LangGraph hybride RAG + SQL, sans dépendance Streamlit. | |
| # Importable à la fois par MistralChat.py et evaluate_ragas.py. | |
| import time | |
| import logfire | |
| from typing import Literal, TypedDict, List | |
| from pydantic import BaseModel, Field | |
| from langgraph.graph import StateGraph, START, END | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_mistralai import ChatMistralAI | |
| from utils.config import ( | |
| MISTRAL_API_KEY, MODEL_NAME, TEMPERATURE, TOP_P, | |
| SEARCH_K, RAG_SYSTEM_PROMT, LLM_CALL_DELAY, | |
| ) | |
| from utils.sql_tool import build_sql_agent | |
| # ============================== | |
| # État du graphe | |
| # ============================== | |
| class AppState(TypedDict, total=False): | |
| user_question: str | |
| route: Literal["rag", "sql"] | |
| rag_contexts: List[str] # Chunks FAISS récupérés (vide si route SQL) | |
| rag_answer: str | |
| sql_answer: str | |
| final_answer: str | |
| # ============================== | |
| # Schéma de décision du routeur | |
| # ============================== | |
| class RouteDecision(BaseModel): | |
| """Valide que la destination ne peut être que 'rag' ou 'sql'.""" | |
| destination: Literal["rag", "sql"] = Field(...) | |
| # ============================== | |
| # Construction du graphe | |
| # ============================== | |
| def build_graph(vector_store_manager): | |
| """ | |
| Construit et retourne (graph, llm). | |
| Args: | |
| vector_store_manager : instance de VectorStoreManager (peut être None | |
| si le Vector Store n'est pas disponible — le nœud | |
| RAG retournera alors un message d'erreur). | |
| """ | |
| llm = ChatMistralAI( | |
| api_key=MISTRAL_API_KEY, | |
| model=MODEL_NAME, | |
| #top_p=TOP_P, | |
| temperature=TEMPERATURE, | |
| ) | |
| router = llm.with_structured_output(RouteDecision) | |
| # Connexion PostgreSQL établie ici (paresseuse), pas à l'import du module | |
| sql_agent = build_sql_agent() | |
| # --- Nœud routeur --- | |
| def router_node(state: AppState) -> dict: | |
| """Classe la question en 'rag' (texte narratif) ou 'sql' (données chiffrées).""" | |
| logfire.info(f"[Router] Question : '{state['user_question']}'") | |
| time.sleep(LLM_CALL_DELAY) # Rate limiting | |
| decision = router.invoke([ | |
| SystemMessage(content=""" | |
| Tu es un routeur pour un assistant NBA. | |
| Choisis : | |
| - 'rag' si la question porte sur du contenu textuel (articles, analyses narratives, contexte, | |
| règles, actualités, discussions Reddit). | |
| - 'sql' si la question demande un calcul, un comptage, un filtrage, un classement ou toute | |
| donnée chiffrée provenant des tables de la base NBA | |
| (joueurs, équipes, points, statistiques de saison). | |
| Réponds uniquement avec la destination. | |
| """), | |
| HumanMessage(content=state["user_question"]) | |
| ]) | |
| logfire.info(f"[Router] → {decision.destination}") | |
| return {"route": decision.destination} | |
| def route_after_router(state: AppState) -> str: | |
| """Renvoie la route choisie pour orienter le graphe vers le bon nœud.""" | |
| return state["route"] | |
| # --- Nœud RAG --- | |
| def rag_node(state: AppState) -> dict: | |
| """Recherche les chunks pertinents dans FAISS, puis génère une réponse contextualisée.""" | |
| question = state["user_question"] | |
| logfire.info(f"[RAG] Recherche pour : '{question}'") | |
| if vector_store_manager is None: | |
| return { | |
| "rag_answer": "Le service de recherche documentaire n'est pas disponible.", | |
| "rag_contexts": [""], | |
| } | |
| results = vector_store_manager.search(question, k=SEARCH_K) | |
| if results: | |
| context = "\n\n---\n\n".join( | |
| f"Source : {r['metadata'].get('source', 'Inconnue')} (Score : {r['score']:.1f}%)\n{r['text']}" | |
| for r in results | |
| ) | |
| contexts_list = [r["text"] for r in results] | |
| else: | |
| context = "Aucune information pertinente trouvée dans la base documentaire." | |
| contexts_list = [""] | |
| logfire.warn("[RAG] Aucun chunk pertinent trouvé.") | |
| time.sleep(LLM_CALL_DELAY) # Rate limiting | |
| response = llm.invoke([ | |
| SystemMessage(content=RAG_SYSTEM_PROMT.format(context_str=context, question=question)) | |
| ]) | |
| logfire.info("[RAG] Réponse générée.") | |
| return {"rag_answer": response.content, "rag_contexts": contexts_list} | |
| # --- Nœud SQL --- | |
| def sql_node(state: AppState) -> dict: | |
| """Délègue la question à l'agent SQL ReAct (sql_tool.py) et récupère sa réponse finale.""" | |
| question = state["user_question"] | |
| logfire.info(f"[SQL] Requête : '{question}'") | |
| time.sleep(LLM_CALL_DELAY) # Rate limiting | |
| result = sql_agent.invoke({ | |
| "messages": [{"role": "user", "content": question}] | |
| }) | |
| final_msg = result["messages"][-1].content | |
| logfire.info("[SQL] Réponse générée.") | |
| # Pas de contexte FAISS pour la route SQL | |
| return {"sql_answer": final_msg, "rag_contexts": [""]} | |
| # --- Nœud de synthèse finale --- | |
| def finalize_node(state: AppState) -> dict: | |
| """Sélectionne la réponse RAG ou SQL selon la route empruntée.""" | |
| if state.get("route") == "rag": | |
| return {"final_answer": state.get("rag_answer", "Aucune réponse RAG disponible.")} | |
| return {"final_answer": state.get("sql_answer", "Aucune réponse SQL disponible.")} | |
| # --- Compilation du graphe --- | |
| builder = StateGraph(AppState) | |
| builder.add_node("router_node", router_node) | |
| builder.add_node("rag_node", rag_node) | |
| builder.add_node("sql_node", sql_node) | |
| builder.add_node("finalize_node", finalize_node) | |
| builder.add_edge(START, "router_node") | |
| builder.add_conditional_edges( | |
| "router_node", | |
| route_after_router, | |
| {"rag": "rag_node", "sql": "sql_node"}, | |
| ) | |
| builder.add_edge("rag_node", "finalize_node") | |
| builder.add_edge("sql_node", "finalize_node") | |
| builder.add_edge("finalize_node", END) | |
| return builder.compile(), llm | |