oc_mlops_projet_3 / utils /langgraph_app.py
CedM's picture
Déploiement automatique depuis GitLab CI
7cb1544 verified
# utils/langgraph_app.py
# Graphe LangGraph hybride RAG + SQL, sans dépendance Streamlit.
# Importable à la fois par MistralChat.py et evaluate_ragas.py.
import time
import logfire
from typing import Literal, TypedDict, List
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_mistralai import ChatMistralAI
from utils.config import (
MISTRAL_API_KEY, MODEL_NAME, TEMPERATURE, TOP_P,
SEARCH_K, RAG_SYSTEM_PROMT, LLM_CALL_DELAY,
)
from utils.sql_tool import build_sql_agent
# ==============================
# État du graphe
# ==============================
class AppState(TypedDict, total=False):
user_question: str
route: Literal["rag", "sql"]
rag_contexts: List[str] # Chunks FAISS récupérés (vide si route SQL)
rag_answer: str
sql_answer: str
final_answer: str
# ==============================
# Schéma de décision du routeur
# ==============================
class RouteDecision(BaseModel):
"""Valide que la destination ne peut être que 'rag' ou 'sql'."""
destination: Literal["rag", "sql"] = Field(...)
# ==============================
# Construction du graphe
# ==============================
def build_graph(vector_store_manager):
"""
Construit et retourne (graph, llm).
Args:
vector_store_manager : instance de VectorStoreManager (peut être None
si le Vector Store n'est pas disponible — le nœud
RAG retournera alors un message d'erreur).
"""
llm = ChatMistralAI(
api_key=MISTRAL_API_KEY,
model=MODEL_NAME,
#top_p=TOP_P,
temperature=TEMPERATURE,
)
router = llm.with_structured_output(RouteDecision)
# Connexion PostgreSQL établie ici (paresseuse), pas à l'import du module
sql_agent = build_sql_agent()
# --- Nœud routeur ---
def router_node(state: AppState) -> dict:
"""Classe la question en 'rag' (texte narratif) ou 'sql' (données chiffrées)."""
logfire.info(f"[Router] Question : '{state['user_question']}'")
time.sleep(LLM_CALL_DELAY) # Rate limiting
decision = router.invoke([
SystemMessage(content="""
Tu es un routeur pour un assistant NBA.
Choisis :
- 'rag' si la question porte sur du contenu textuel (articles, analyses narratives, contexte,
règles, actualités, discussions Reddit).
- 'sql' si la question demande un calcul, un comptage, un filtrage, un classement ou toute
donnée chiffrée provenant des tables de la base NBA
(joueurs, équipes, points, statistiques de saison).
Réponds uniquement avec la destination.
"""),
HumanMessage(content=state["user_question"])
])
logfire.info(f"[Router] → {decision.destination}")
return {"route": decision.destination}
def route_after_router(state: AppState) -> str:
"""Renvoie la route choisie pour orienter le graphe vers le bon nœud."""
return state["route"]
# --- Nœud RAG ---
def rag_node(state: AppState) -> dict:
"""Recherche les chunks pertinents dans FAISS, puis génère une réponse contextualisée."""
question = state["user_question"]
logfire.info(f"[RAG] Recherche pour : '{question}'")
if vector_store_manager is None:
return {
"rag_answer": "Le service de recherche documentaire n'est pas disponible.",
"rag_contexts": [""],
}
results = vector_store_manager.search(question, k=SEARCH_K)
if results:
context = "\n\n---\n\n".join(
f"Source : {r['metadata'].get('source', 'Inconnue')} (Score : {r['score']:.1f}%)\n{r['text']}"
for r in results
)
contexts_list = [r["text"] for r in results]
else:
context = "Aucune information pertinente trouvée dans la base documentaire."
contexts_list = [""]
logfire.warn("[RAG] Aucun chunk pertinent trouvé.")
time.sleep(LLM_CALL_DELAY) # Rate limiting
response = llm.invoke([
SystemMessage(content=RAG_SYSTEM_PROMT.format(context_str=context, question=question))
])
logfire.info("[RAG] Réponse générée.")
return {"rag_answer": response.content, "rag_contexts": contexts_list}
# --- Nœud SQL ---
def sql_node(state: AppState) -> dict:
"""Délègue la question à l'agent SQL ReAct (sql_tool.py) et récupère sa réponse finale."""
question = state["user_question"]
logfire.info(f"[SQL] Requête : '{question}'")
time.sleep(LLM_CALL_DELAY) # Rate limiting
result = sql_agent.invoke({
"messages": [{"role": "user", "content": question}]
})
final_msg = result["messages"][-1].content
logfire.info("[SQL] Réponse générée.")
# Pas de contexte FAISS pour la route SQL
return {"sql_answer": final_msg, "rag_contexts": [""]}
# --- Nœud de synthèse finale ---
def finalize_node(state: AppState) -> dict:
"""Sélectionne la réponse RAG ou SQL selon la route empruntée."""
if state.get("route") == "rag":
return {"final_answer": state.get("rag_answer", "Aucune réponse RAG disponible.")}
return {"final_answer": state.get("sql_answer", "Aucune réponse SQL disponible.")}
# --- Compilation du graphe ---
builder = StateGraph(AppState)
builder.add_node("router_node", router_node)
builder.add_node("rag_node", rag_node)
builder.add_node("sql_node", sql_node)
builder.add_node("finalize_node", finalize_node)
builder.add_edge(START, "router_node")
builder.add_conditional_edges(
"router_node",
route_after_router,
{"rag": "rag_node", "sql": "sql_node"},
)
builder.add_edge("rag_node", "finalize_node")
builder.add_edge("sql_node", "finalize_node")
builder.add_edge("finalize_node", END)
return builder.compile(), llm