alexandre-cameron-borges commited on
Commit
01852e2
·
verified ·
1 Parent(s): c93ee2a

Upload agent.py

Browse files
Files changed (1) hide show
  1. src/agent.py +377 -0
src/agent.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+
3
+ import os
4
+ import json
5
+ import requests
6
+ from typing import Any, Dict
7
+
8
+ from dotenv import load_dotenv
9
+ load_dotenv()
10
+
11
+ # ========== CONFIG ==========
12
+ OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
13
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")
14
+ CHROMA_DIR = os.getenv("CHROMA_DIR", "./chroma_store")
15
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
16
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
17
+
18
+ # ========== LLM & Embeddings ==========
19
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
20
+ from langchain_core.tools import tool
21
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
22
+ from langgraph.prebuilt import create_react_agent
23
+
24
+ # ========== Calculator ==========
25
+ import numexpr as ne
26
+ from pydantic import BaseModel, Field
27
+
28
+
29
+ class CalcInput(BaseModel):
30
+ expression: str = Field(
31
+ description="Expression mathématique à évaluer, ex: '3*(2+5)**2'"
32
+ )
33
+
34
+
35
+ @tool("calculator", args_schema=CalcInput)
36
+ def calculator(expression: str) -> str:
37
+ """Calculette via numexpr pour évaluer une expression mathématique."""
38
+ try:
39
+ res = ne.evaluate(expression)
40
+ return str(res.item() if hasattr(res, "item") else res)
41
+ except Exception as e:
42
+ return f"CALC_ERROR: {e}"
43
+
44
+
45
+ # ========== Tavily Search ==========
46
+ from langchain_community.tools.tavily_search import TavilySearchResults
47
+
48
+ web_search_tool = TavilySearchResults(tavily_api_key=TAVILY_API_KEY)
49
+
50
+ # ========== RAG (Chroma) ==========
51
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
52
+ from langchain_chroma import Chroma
53
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader
54
+
55
+ embeddings = OpenAIEmbeddings(model=EMBED_MODEL, api_key=OPENAI_API_KEY)
56
+ vectorstore = Chroma(
57
+ collection_name="rag_collection",
58
+ embedding_function=embeddings,
59
+ persist_directory=CHROMA_DIR,
60
+ )
61
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
62
+
63
+ RAG_INITIALIZED = False
64
+
65
+
66
+ def download(url: str) -> str:
67
+ """Télécharge un fichier depuis une URL et le stocke dans ./downloaded_docs."""
68
+ os.makedirs("./downloaded_docs", exist_ok=True)
69
+ path = "./downloaded_docs/" + url.split("/")[-1]
70
+ r = requests.get(url)
71
+ r.raise_for_status()
72
+ with open(path, "wb") as f:
73
+ f.write(r.content)
74
+ print("Downloaded", path)
75
+ return path
76
+
77
+
78
+ def ingest_file(path: str) -> int:
79
+ """Ingestion d’un fichier (PDF/CSV/TXT) dans le vector store Chroma."""
80
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=120)
81
+
82
+ if path.lower().endswith(".pdf"):
83
+ docs = PyPDFLoader(path).load()
84
+ elif path.lower().endswith(".csv"):
85
+ docs = CSVLoader(path).load()
86
+ else:
87
+ docs = TextLoader(path, encoding="utf-8").load()
88
+
89
+ chunks = splitter.split_documents(docs)
90
+ vectorstore.add_documents(chunks)
91
+ print(f"Ingested {len(chunks)} chunks from {path}")
92
+ return len(chunks)
93
+
94
+
95
+ def init_rag():
96
+ """Initialise le RAG (télécharge + ingère le PDF) une seule fois."""
97
+ global RAG_INITIALIZED
98
+ if RAG_INITIALIZED:
99
+ return
100
+ url = (
101
+ "https://raw.githubusercontent.com/Projet-MLOps-Team/Projet_MLOps-GenAI/main/conditions-tarifaires-particuliers-2025.pdf"
102
+ )
103
+ path = download(url)
104
+ ingest_file(path)
105
+ RAG_INITIALIZED = True
106
+ print("✅ RAG initialisé (PDF conditions tarifaires ingéré)")
107
+
108
+
109
+ class RagInput(BaseModel):
110
+ query: str = Field(description="Question en langage naturel.")
111
+ k: int = Field(
112
+ default=5,
113
+ ge=1,
114
+ le=20,
115
+ description="Nombre maximum de passages RAG à renvoyer.",
116
+ )
117
+
118
+
119
+ @tool("rag_search", args_schema=RagInput)
120
+ def rag_search(query: str, k: int = 5) -> str:
121
+ """Recherche des passages pertinents dans le vector store Chroma (RAG)."""
122
+ try:
123
+ docs = retriever.invoke(query)
124
+ if not docs:
125
+ return f"RAG_EMPTY: Aucun document trouvé pour la requête: {query}"
126
+
127
+ docs = docs[:k]
128
+ lines = [f"RAG_HITS: {len(docs)} résultats pour: {query}"]
129
+ for i, d in enumerate(docs, 1):
130
+ meta = d.metadata or {}
131
+ src = meta.get("source") or meta.get("file_path") or "unknown"
132
+ page = meta.get("page", "?")
133
+ txt = d.page_content.replace("\n", " ")
134
+ if len(txt) > 600:
135
+ txt = txt[:600] + "…"
136
+ lines.append(f"[{i}] (page {page}) {src}: {txt}")
137
+ return "\n".join(lines)
138
+ except Exception as e:
139
+ return f"RAG_ERROR: {e}"
140
+
141
+
142
+ # ========== ML Prediction Tool (remote .pkl on S3) ==========
143
+ import pandas as pd
144
+ import joblib
145
+ from io import BytesIO
146
+
147
+
148
+ class MLPredictInput(BaseModel):
149
+ payload: Dict[str, Any] = Field(
150
+ description="Dictionnaire de features pour la prédiction ML."
151
+ )
152
+
153
+
154
+ MODEL_URL = "https://mlopsgenaiapp.s3.eu-west-3.amazonaws.com/best_model.pkl"
155
+ remote_model = None
156
+
157
+
158
+ def load_remote_model(url: str):
159
+ """Télécharge un modèle pickle distant et le charge en mémoire."""
160
+ print(f"📡 Téléchargement du modèle distant : {url}")
161
+ resp = requests.get(url)
162
+ resp.raise_for_status()
163
+ buffer = BytesIO(resp.content)
164
+ model = joblib.load(buffer)
165
+ print("✅ Modèle distant chargé en mémoire")
166
+ return model
167
+
168
+
169
+ try:
170
+ remote_model = load_remote_model(MODEL_URL)
171
+ except Exception as e:
172
+ print(f"❌ ERREUR chargement modèle distant : {e}")
173
+ remote_model = None
174
+
175
+
176
+ def _align_features(df: pd.DataFrame):
177
+ """Aligne l'ordre et le set de features avec ceux utilisés au fit."""
178
+ feature_names = getattr(remote_model, "feature_names_in_", None)
179
+ if feature_names is None:
180
+ return df
181
+
182
+ missing = [f for f in feature_names if f not in df.columns]
183
+ if missing:
184
+ raise ValueError(
185
+ f"Features manquantes pour le modèle : {missing}. "
186
+ f"Features reçues : {list(df.columns)}"
187
+ )
188
+
189
+ return df[list(feature_names)]
190
+
191
+
192
+ def _predict_remote(features: Dict[str, Any]) -> Dict[str, Any]:
193
+ """Prédiction via modèle .pkl chargé depuis S3, avec sortie enrichie."""
194
+ if remote_model is None:
195
+ raise RuntimeError("Modèle distant non chargé.")
196
+
197
+ df = pd.DataFrame([features])
198
+ df = _align_features(df)
199
+
200
+ y_pred = remote_model.predict(df)[0]
201
+
202
+ proba_default = None
203
+ if hasattr(remote_model, "predict_proba"):
204
+ proba_default = float(remote_model.predict_proba(df)[0, 1])
205
+
206
+ if int(y_pred) == 1:
207
+ label_name = "Défaut probable"
208
+ else:
209
+ label_name = "Client plutôt sain"
210
+
211
+ risk_level = None
212
+ if proba_default is not None:
213
+ if proba_default < 0.20:
214
+ risk_level = "faible"
215
+ elif proba_default < 0.50:
216
+ risk_level = "modéré"
217
+ else:
218
+ risk_level = "élevé"
219
+
220
+ if proba_default is not None and risk_level is not None:
221
+ explanation = (
222
+ f"Le modèle estime une probabilité de défaut d’environ "
223
+ f"{proba_default*100:.1f} %, ce qui correspond à un risque {risk_level}."
224
+ )
225
+ else:
226
+ explanation = (
227
+ "Le modèle ne fournit pas de probabilité explicite, seulement une classe prédite."
228
+ )
229
+
230
+ return {
231
+ "label": int(y_pred),
232
+ "label_name": label_name,
233
+ "proba_default": proba_default,
234
+ "risk_level": risk_level,
235
+ "explanation": explanation,
236
+ "features_used": list(df.columns),
237
+ }
238
+
239
+
240
+ def _jsonable(x: Any) -> Any:
241
+ """Conversion best-effort en objet JSON-serialisable."""
242
+ try:
243
+ json.dumps(x)
244
+ return x
245
+ except TypeError:
246
+ if hasattr(x, "tolist"):
247
+ return x.tolist()
248
+ return str(x)
249
+
250
+
251
+ @tool("ml_predict", args_schema=MLPredictInput)
252
+ def ml_predict(payload: Dict[str, Any]) -> str:
253
+ """Effectue une prédiction via un modèle .pkl hébergé sur S3, avec sortie enrichie."""
254
+ try:
255
+ result = _predict_remote(payload)
256
+ pretty = {
257
+ "kind": "remote_pickle",
258
+ "prediction": _jsonable(result),
259
+ }
260
+ return json.dumps(pretty, ensure_ascii=False, indent=2)
261
+ except Exception as e:
262
+ return f"ML_ERROR: {e}"
263
+
264
+
265
+ # ========== SYSTEM PROMPT (texte) ==========
266
+ SYSTEM_PROMPT_TEXT = """
267
+ Tu es un assistant bancaire expert en défaut de crédit et conditions tarifaires 2025, doté d’une mémoire contextuelle
268
+ et de plusieurs outils spécialisés. Ton rôle est de sélectionner automatiquement l’outil pertinent,
269
+ d'utiliser intelligemment la mémoire issue du RAG, et de produire une réponse synthétique, fiable et systématique.
270
+
271
+ [ MÉMOIRE ]
272
+ - Considère le contenu indexé dans le RAG comme ta mémoire fiable pour les tarifs bancaires.
273
+ - Consulte systématiquement `rag_search` pour toute requête liée à : tarifs, frais, commissions, comptes, cartes,
274
+ packages, virements, incidents, clientèle (résident / non résident, jeune, premium, etc.).
275
+ - Ne JAMAIS inventer de montant : si les documents ne contiennent pas l’information, dis-le explicitement.
276
+
277
+ [ CHOIX DES OUTILS ]
278
+ 1) RAG (`rag_search`) – PRIORITAIRE :
279
+ - Utilise-le quand la question concerne des tarifs, frais, conditions, offres, segments de clientèle.
280
+ - Formule une requête courte, précise, en français (ex: “tenue de compte actif non résident”).
281
+
282
+ 2) Web Search (`web_search_tool`) :
283
+ - Utilise-le pour les actualités, contexte macro, informations externes non présentes dans les documents.
284
+ - Ne pas l’utiliser pour confirmer un chiffre qui devrait venir du PDF.
285
+
286
+ 3) ML Prediction (`ml_predict`) :
287
+ - Utilise-le si l’utilisateur demande une estimation de risque crédit ou une prédiction à partir de features.
288
+ - Transmets fidèlement les features fournies et explique le résultat (classe, probabilité, niveau de risque).
289
+
290
+ 4) Calculator (`calculator`) :
291
+ - Utilise-le pour les calculs mathématiques explicites (montants, pourcentages, ratios).
292
+
293
+ [ COMPORTEMENT ]
294
+ - Si la question peut utiliser plusieurs outils, privilégie d’abord `rag_search`.
295
+ - Si `rag_search` renvoie RAG_EMPTY ou RAG_ERROR, explique que l’info n’est pas dans les documents et n’invente rien.
296
+ - Si aucun outil n’est pertinent, demande une clarification courte ou réponds avec ce que tu peux déduire sans halluciner.
297
+
298
+ [ STYLE ]
299
+ - Toujours en français.
300
+ - Réponses claires, concises, structurées.
301
+ - Pour les tarifs, privilégie un tableau (type de compte | client | montant | périodicité) + une courte synthèse.
302
+ """.strip()
303
+
304
+
305
+ # ========== Agent factory ==========
306
+ def build_agent():
307
+ """Construit l’agent ReAct avec les tools calcul, RAG, web et ML."""
308
+ init_rag()
309
+
310
+ llm = ChatOpenAI(
311
+ model=OPENAI_MODEL,
312
+ api_key=OPENAI_API_KEY,
313
+ temperature=0,
314
+ )
315
+ tools = [calculator, rag_search, web_search_tool]
316
+ if remote_model is not None:
317
+ tools.append(ml_predict)
318
+
319
+ # Prompt compatible avec create_react_agent (version récente) :
320
+ prompt = ChatPromptTemplate.from_messages(
321
+ [
322
+ ("system", SYSTEM_PROMPT_TEXT),
323
+ MessagesPlaceholder("messages"),
324
+ ]
325
+ )
326
+
327
+ return create_react_agent(
328
+ llm,
329
+ tools,
330
+ prompt=prompt,
331
+ )
332
+
333
+
334
+ def chat(agent, messages: list, recursion_limit: int = 40) -> str:
335
+ """
336
+ messages = liste de dicts {"role": "user"/"assistant", "content": "..."}
337
+ On convertit au format attendu par LangGraph: [("user", "..."), ("assistant", "..."), ...]
338
+ """
339
+ try:
340
+ lc_messages = [(m["role"], m["content"]) for m in messages]
341
+ out = agent.invoke(
342
+ {"messages": lc_messages},
343
+ config={"recursion_limit": recursion_limit},
344
+ )
345
+ return out["messages"][-1].content
346
+ except Exception as e:
347
+ return f"AGENT_ERROR: {e}"
348
+
349
+
350
+
351
+ # ========== MAIN ==========
352
+ if __name__ == "__main__":
353
+ print("Bootstrapping agent...")
354
+
355
+ agent = build_agent()
356
+
357
+ print("\n[Calc]")
358
+ print(chat(agent, "Calcule 3*(2+5)**2 et explique en une ligne."))
359
+
360
+ print("\n[RAG]")
361
+ print(
362
+ chat(
363
+ agent,
364
+ "Résume-moi les frais de tenue de compte pour un non résident en utilisant ton outil rag_search.",
365
+ )
366
+ )
367
+
368
+ print("\n[ML]")
369
+ print(
370
+ chat(
371
+ agent,
372
+ "Appelle ml_predict avec "
373
+ "{'credit_lines_outstanding': 5, 'loan_amt_outstanding': 15000, "
374
+ "'total_debt_outstanding': 25000, 'income': 60000, 'years_employed': 10, "
375
+ "'fico_score': 720, 'debt_ratio': 0.3} et explique le résultat.",
376
+ )
377
+ )