File size: 1,878 Bytes
7cb1544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Agent SQL ReAct pour le projet NBA Analyst AI.
  - Connexion PostgreSQL via SQLDatabaseToolkit
  - Agent ReAct (LangGraph prebuilt) pour raisonner sur les données tabulaires NBA
"""

import sys, os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from langgraph.prebuilt import create_react_agent

from langchain_mistralai import ChatMistralAI
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit

from utils.config import (
    MISTRAL_API_KEY, MODEL_NAME, TEMPERATURE, TOP_P, PG_URL_READONLY, SQL_SYSTEM_PROMPT
)


def build_sql_agent():
    """
    Crée et retourne l'agent SQL ReAct.
    La connexion PostgreSQL n'est établie qu'à l'appel de cette fonction
    (connexion paresseuse, évite les erreurs à l'import si la BDD est éteinte).
    """
    # -----------------------------
    # LLM dédié aux requêtes SQL
    # -----------------------------
    sql_llm = ChatMistralAI(
        api_key=MISTRAL_API_KEY,
        model=MODEL_NAME,
        #top_p=TOP_P,
        temperature=TEMPERATURE,
    )

    # -----------------------------
    # Connexion PostgreSQL + outillage
    # -----------------------------
    db          = SQLDatabase.from_uri(PG_URL_READONLY)
    sql_toolkit = SQLDatabaseToolkit(db=db, llm=sql_llm)
    sql_tools   = sql_toolkit.get_tools()

    # -----------------------------
    # Prompt système de l'agent SQL
    # -----------------------------
    sql_system_prompt = SQL_SYSTEM_PROMPT.format(dialect=db.dialect)

    # -----------------------------
    # Agent ReAct SQL
    # -----------------------------
    # Raisonne en plusieurs étapes (ReAct) pour répondre aux questions tabulaires
    sql_agent = create_react_agent(
        sql_llm,
        sql_tools,
        prompt=sql_system_prompt,
    )

    return sql_agent