Vieuxwalo commited on
Commit
314274f
·
verified ·
1 Parent(s): 40ac0c5

Create llm_engine.py

Browse files
Files changed (1) hide show
  1. llm_engine.py +355 -0
llm_engine.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ llm_engine.py — Moteur d'inférence LLM
3
+ Gère le chargement des modèles, la génération de texte et le fallback.
4
+ """
5
+
6
+ import time
7
+ import logging
8
+ from typing import Optional, Tuple, Generator
9
+
10
+ logging.basicConfig(level=logging.INFO, format="[%(name)s] %(levelname)s: %(message)s")
11
+ logger = logging.getLogger("LLMEngine")
12
+
13
+ from config import (
14
+ LLM_MODEL, QA_MODEL, MAX_NEW_TOKENS, TEMPERATURE,
15
+ TOP_P, REPETITION_PENALTY, DO_SAMPLE
16
+ )
17
+ from utils import clean_response, is_valid_response, format_error_message
18
+ from prompts import build_chat_prompt, build_qa_context
19
+
20
+
21
+ class LLMEngine:
22
+ """
23
+ Moteur d'inférence principal avec système de fallback en cascade.
24
+
25
+ Cascade de fallback :
26
+ 1. Modèle LLM principal (génération chat)
27
+ 2. Modèle QA (question-answering sur contexte)
28
+ 3. Réponse de fallback statique
29
+
30
+ Cette architecture garantit qu'une réponse est toujours retournée,
31
+ même si les modèles principaux ne sont pas disponibles.
32
+ """
33
+
34
+ def __init__(self):
35
+ self.text_pipeline = None # Pipeline génération de texte (LLM)
36
+ self.qa_pipeline = None # Pipeline question-answering (fallback)
37
+ self.models_loaded = False
38
+ self._load_models()
39
+
40
+ def _load_models(self) -> None:
41
+ """
42
+ Charge les modèles IA de manière sécurisée.
43
+ Utilise lazy loading — ne bloque pas le démarrage si un modèle échoue.
44
+ """
45
+ logger.info(f"Chargement du modèle LLM : {LLM_MODEL}")
46
+
47
+ # Import différé pour éviter les erreurs si transformers n'est pas installé
48
+ try:
49
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
50
+ import torch
51
+
52
+ # Détection du device (GPU si disponible, sinon CPU)
53
+ device = 0 if torch.cuda.is_available() else -1
54
+ device_name = "GPU (CUDA)" if device == 0 else "CPU"
55
+ logger.info(f"Device sélectionné : {device_name}")
56
+
57
+ # ── Pipeline LLM principal ──────────────────────────────────────
58
+ try:
59
+ self.text_pipeline = pipeline(
60
+ "text-generation",
61
+ model=LLM_MODEL,
62
+ device=device,
63
+ # Paramètres d'optimisation mémoire
64
+ torch_dtype="auto", # Sélectionne float16 sur GPU
65
+ trust_remote_code=False,
66
+ )
67
+ logger.info(f"✅ LLM chargé : {LLM_MODEL}")
68
+ except Exception as e:
69
+ logger.warning(f"⚠️ Échec LLM principal : {e}")
70
+ self.text_pipeline = None
71
+
72
+ # ── Pipeline QA de fallback ─────────────────────────────────────
73
+ try:
74
+ self.qa_pipeline = pipeline(
75
+ "question-answering",
76
+ model=QA_MODEL,
77
+ device=device,
78
+ )
79
+ logger.info(f"✅ QA pipeline chargé : {QA_MODEL}")
80
+ except Exception as e:
81
+ logger.warning(f"⚠️ Échec QA pipeline : {e}")
82
+ self.qa_pipeline = None
83
+
84
+ self.models_loaded = self.text_pipeline is not None or self.qa_pipeline is not None
85
+
86
+ except ImportError:
87
+ logger.error("❌ transformers non installé. Installez avec : pip install transformers torch")
88
+ self.models_loaded = False
89
+
90
+ def generate(
91
+ self,
92
+ user_message: str,
93
+ conversation_history: list,
94
+ domain: Optional[str] = None,
95
+ max_tokens: int = MAX_NEW_TOKENS,
96
+ ) -> Tuple[str, str]:
97
+ """
98
+ Génère une réponse pour le message utilisateur.
99
+
100
+ Cascade de fallback :
101
+ 1. LLM principal → génération de texte contextualisée
102
+ 2. QA pipeline → extraction de réponse depuis contexte
103
+ 3. Message d'erreur informatif
104
+
105
+ Args:
106
+ user_message: Le message de l'utilisateur
107
+ conversation_history: Historique récent de la conversation
108
+ domain: Domaine détecté pour les prompts spécialisés
109
+ max_tokens: Nombre maximum de tokens à générer
110
+
111
+ Returns:
112
+ Tuple (réponse, source) où source ∈ {'llm', 'qa', 'fallback'}
113
+ """
114
+ start_time = time.time()
115
+
116
+ # ── Tentative 1 : LLM principal ────────────────────────────────────
117
+ if self.text_pipeline is not None:
118
+ response, source = self._generate_with_llm(
119
+ user_message, conversation_history, domain, max_tokens
120
+ )
121
+ if response:
122
+ elapsed = time.time() - start_time
123
+ logger.info(f"[LLM] Réponse générée en {elapsed:.2f}s ({source})")
124
+ return response, source
125
+
126
+ # ── Tentative 2 : QA pipeline ──────────────────────────────────────
127
+ if self.qa_pipeline is not None:
128
+ response, source = self._generate_with_qa(user_message, domain)
129
+ if response:
130
+ elapsed = time.time() - start_time
131
+ logger.info(f"[QA] Réponse extraite en {elapsed:.2f}s ({source})")
132
+ return response, source
133
+
134
+ # ── Fallback final ─────────────────────────────────────────────────
135
+ logger.warning("Tous les modèles ont échoué. Retour message de fallback.")
136
+ fallback = self._get_fallback_response(user_message, domain)
137
+ return fallback, "fallback"
138
+
139
+ def _generate_with_llm(
140
+ self,
141
+ user_message: str,
142
+ conversation_history: list,
143
+ domain: Optional[str],
144
+ max_tokens: int,
145
+ ) -> Tuple[Optional[str], str]:
146
+ """
147
+ Génère avec le LLM principal (pipeline text-generation).
148
+
149
+ Utilise le format ChatML pour structurer le prompt.
150
+ Extrait uniquement la partie 'assistant' de la sortie.
151
+
152
+ Returns:
153
+ Tuple (réponse nettoyée, 'llm') ou (None, 'llm_failed')
154
+ """
155
+ try:
156
+ # Construction du prompt formaté
157
+ prompt = build_chat_prompt(conversation_history, user_message, domain)
158
+
159
+ # Génération
160
+ outputs = self.text_pipeline(
161
+ prompt,
162
+ max_new_tokens=max_tokens,
163
+ temperature=TEMPERATURE,
164
+ top_p=TOP_P,
165
+ repetition_penalty=REPETITION_PENALTY,
166
+ do_sample=DO_SAMPLE,
167
+ return_full_text=False, # Retourne uniquement la partie générée
168
+ pad_token_id=self.text_pipeline.tokenizer.eos_token_id,
169
+ )
170
+
171
+ if not outputs or not outputs[0]:
172
+ return None, "llm_empty"
173
+
174
+ generated_text = outputs[0].get("generated_text", "")
175
+
176
+ # Extraction de la réponse : prend tout avant le prochain <|user|>
177
+ if "<|user|>" in generated_text:
178
+ generated_text = generated_text.split("<|user|>")[0]
179
+ if "<|system|>" in generated_text:
180
+ generated_text = generated_text.split("<|system|>")[0]
181
+
182
+ # Nettoyage
183
+ response = clean_response(generated_text)
184
+
185
+ if is_valid_response(response):
186
+ return response, "llm"
187
+
188
+ return None, "llm_short"
189
+
190
+ except Exception as e:
191
+ logger.error(f"Erreur génération LLM : {e}")
192
+ return None, "llm_error"
193
+
194
+ def _generate_with_qa(
195
+ self,
196
+ user_message: str,
197
+ domain: Optional[str],
198
+ ) -> Tuple[Optional[str], str]:
199
+ """
200
+ Extrait une réponse via le pipeline question-answering.
201
+
202
+ Utilise un contexte enrichi selon le domaine détecté.
203
+ Plus fiable que le LLM pour les questions factuelles courtes.
204
+
205
+ Returns:
206
+ Tuple (réponse, 'qa') ou (None, 'qa_failed')
207
+ """
208
+ try:
209
+ context = build_qa_context(domain)
210
+
211
+ result = self.qa_pipeline(
212
+ question=user_message,
213
+ context=context,
214
+ max_answer_len=256,
215
+ )
216
+
217
+ answer = result.get("answer", "").strip()
218
+ score = result.get("score", 0)
219
+
220
+ logger.info(f"[QA] Score de confiance : {score:.3f}")
221
+
222
+ # Accepter la réponse si la confiance est suffisante
223
+ if score > 0.1 and is_valid_response(answer):
224
+ return answer, "qa"
225
+
226
+ return None, "qa_low_confidence"
227
+
228
+ except Exception as e:
229
+ logger.error(f"Erreur QA pipeline : {e}")
230
+ return None, "qa_error"
231
+
232
+ def _get_fallback_response(self, user_message: str, domain: Optional[str]) -> str:
233
+ """
234
+ Génère une réponse de secours informative basée sur des règles simples.
235
+
236
+ Analyse les mots-clés de la question pour retourner une réponse
237
+ pertinente depuis un mini-dictionnaire intégré.
238
+
239
+ Args:
240
+ user_message: La question de l'utilisateur
241
+ domain: Domaine détecté
242
+
243
+ Returns:
244
+ Réponse textuelle de fallback
245
+ """
246
+ msg_lower = user_message.lower()
247
+
248
+ # Mini-base de réponses intégrées (fallback ultime)
249
+ fallback_rules = {
250
+ # Réseaux
251
+ ("switch", "commutateur"): (
252
+ "Un **switch** (commutateur) est un équipement réseau de couche 2 (OSI) "
253
+ "qui interconnecte des appareils dans un réseau local (LAN). "
254
+ "Il utilise les adresses MAC pour acheminer les trames vers le bon port. "
255
+ "Commandes Cisco de base :\n"
256
+ "```\nSwitch> enable\nSwitch# show mac address-table\nSwitch# show interfaces\n```"
257
+ ),
258
+ ("routeur", "router"): (
259
+ "Un **routeur** est un équipement réseau de couche 3 (OSI) qui interconnecte "
260
+ "plusieurs réseaux différents. Il utilise les adresses IP et une table de routage "
261
+ "pour acheminer les paquets.\n"
262
+ "Commandes Cisco de base :\n"
263
+ "```\nRouter> enable\nRouter# show ip route\nRouter# show ip interface brief\n```"
264
+ ),
265
+ ("vlan",): (
266
+ "Un **VLAN** (Virtual LAN) permet de segmenter logiquement un réseau physique "
267
+ "en plusieurs réseaux virtuels isolés. Configuration Cisco :\n"
268
+ "```\nSwitch(config)# vlan 10\nSwitch(config-vlan)# name SERVEURS\n"
269
+ "Switch(config)# interface fa0/1\nSwitch(config-if)# switchport mode access\n"
270
+ "Switch(config-if)# switchport access vlan 10\n```"
271
+ ),
272
+ ("ospf",): (
273
+ "**OSPF** (Open Shortest Path First) est un protocole de routage dynamique "
274
+ "à état de lien (Link-State). Il utilise l'algorithme de Dijkstra pour calculer "
275
+ "les meilleurs chemins. Configuration Cisco :\n"
276
+ "```\nRouter(config)# router ospf 1\n"
277
+ "Router(config-router)# network 192.168.1.0 0.0.0.255 area 0\n```"
278
+ ),
279
+ # Cybersécurité
280
+ ("vpn",): (
281
+ "Un **VPN** (Virtual Private Network) crée un tunnel chiffré entre deux points "
282
+ "sur Internet, assurant confidentialité et intégrité des données. "
283
+ "Types principaux : Site-to-Site (deux réseaux), Remote Access (nomade). "
284
+ "Protocoles : IPSec, OpenVPN, WireGuard, SSL/TLS."
285
+ ),
286
+ ("firewall", "pare-feu"): (
287
+ "Un **pare-feu** (firewall) filtre le trafic réseau selon des règles de sécurité. "
288
+ "Types : stateless (filtre par paquet), stateful (suit les connexions), "
289
+ "applicatif (inspecte le contenu - NGFW). "
290
+ "Il constitue la première ligne de défense du réseau."
291
+ ),
292
+ # IA/ML
293
+ ("machine learning", "apprentissage automatique"): (
294
+ "Le **Machine Learning** est un sous-domaine de l'IA où les algorithmes "
295
+ "apprennent automatiquement à partir de données. "
296
+ "3 types principaux :\n"
297
+ "- **Supervisé** : données étiquetées (classification, régression)\n"
298
+ "- **Non supervisé** : données non étiquetées (clustering)\n"
299
+ "- **Par renforcement** : apprentissage par récompenses"
300
+ ),
301
+ ("llm", "grand modèle de langage"): (
302
+ "Un **LLM** (Large Language Model) est un modèle de langage entraîné sur "
303
+ "d'immenses corpus textuels. Il utilise l'architecture Transformer. "
304
+ "Exemples : GPT-4, Claude, LLaMA, Mistral. "
305
+ "Ils excellent dans la génération de texte, la traduction, le code et le Q&A."
306
+ ),
307
+ }
308
+
309
+ # Cherche la première règle correspondante
310
+ for keywords, response in fallback_rules.items():
311
+ if any(kw in msg_lower for kw in keywords):
312
+ return response
313
+
314
+ # Réponse générique selon le domaine
315
+ domain_generic = {
316
+ "réseaux": (
317
+ "Je n'ai pas trouvé de réponse précise à votre question sur les réseaux. "
318
+ "Pour approfondir ce sujet, je recommande :\n"
319
+ "- La documentation Cisco (cisco.com/c/en/us/support)\n"
320
+ "- Les cours CCNA sur NetAcad (netacad.com)\n"
321
+ "- Packet Tracer pour la pratique en simulation"
322
+ ),
323
+ "cybersécurité": (
324
+ "Je n'ai pas trouvé de réponse précise à votre question en cybersécurité. "
325
+ "Ressources recommandées :\n"
326
+ "- OWASP (owasp.org) pour la sécurité applicative\n"
327
+ "- SANS Institute (sans.org) pour les formations\n"
328
+ "- TryHackMe / HackTheBox pour la pratique"
329
+ ),
330
+ "ia": (
331
+ "Je n'ai pas trouvé de réponse précise à votre question sur l'IA/ML. "
332
+ "Ressources recommandées :\n"
333
+ "- Coursera / DeepLearning.AI (Andrew Ng)\n"
334
+ "- Hugging Face (huggingface.co) pour les modèles\n"
335
+ "- Papers With Code (paperswithcode.com)"
336
+ ),
337
+ }
338
+
339
+ if domain and domain in domain_generic:
340
+ return domain_generic[domain]
341
+
342
+ return (
343
+ "Je n'ai pas pu générer une réponse complète à votre question. "
344
+ "Pourriez-vous la reformuler ou la préciser ? "
345
+ "WENDAA AI couvre les domaines : réseaux, cybersécurité, IA/ML et data."
346
+ )
347
+
348
+ def get_status(self) -> dict:
349
+ """Retourne le statut des modèles chargés."""
350
+ return {
351
+ "llm_loaded": self.text_pipeline is not None,
352
+ "qa_loaded": self.qa_pipeline is not None,
353
+ "model_name": LLM_MODEL,
354
+ "qa_model_name": QA_MODEL,
355
+ }