File size: 20,770 Bytes
7cb1544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# utils/vector_store.py
import os
import time
import pickle
import faiss
import numpy as np
import logfire
from typing import List, Dict, Optional, Any
from mistralai import Mistral, SDKError
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document # Utilisé pour le format attendu par le splitter
from pydantic import BaseModel, Field, field_validator

from .config import (
    MISTRAL_API_KEY, EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE,
    FAISS_INDEX_FILE, DOCUMENT_CHUNKS_FILE, CHUNK_SIZE, CHUNK_OVERLAP,
    TEMPERATURE, TOP_P
)


# ---------------------------
# Schemas Pydantic V2
# ---------------------------
class ChunkIn(BaseModel):
    """Représente les données d'entrée pour le découpage d'un document en chunks.
    Correspond à la structure des documents chargés (page_content + metadata)."""
    page_content: str = Field(..., min_length=1, description="Contenu textuel du document")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées du document (filename, source, category...)")

    @field_validator('page_content')
    @classmethod
    def text_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError('Le texte du document ne peut pas être vide')
        return v.strip()


class EmbeddingIn(BaseModel):
    """Représente un chunk découpé prêt pour la génération d'embeddings.
    Correspond à la structure des dicts produits par _split_documents_to_chunks."""
    id: str = Field(..., min_length=1, description="Identifiant unique du chunk (doc_index_chunk_index)")
    text: str = Field(..., min_length=1, description="Texte du chunk")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées héritées + chunk_id_in_doc + start_index")

    @field_validator('text')
    @classmethod
    def chunk_text_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError('Le texte du chunk ne peut pas être vide')
        return v.strip()


class VectordbIn(BaseModel):
    """Représente les données d'entrée pour l'ajout d'un chunk et de son embedding dans l'index Faiss."""
    chunk_id: str = Field(..., min_length=1, description="Identifiant du chunk correspondant (clé 'id' du chunk)")
    vector: List[float] = Field(..., min_length=1, description="Vecteur d'embedding")

class VectorStoreManager:
    """Gère la création, le chargement et la recherche dans un index Faiss."""

    def __init__(self):
        self.index: Optional[faiss.Index] = None
        self.document_chunks: List[Dict[str, any]] = []
        self.mistral_client = Mistral(api_key=MISTRAL_API_KEY)
        self._load_index_and_chunks()

    def _load_index_and_chunks(self):
        """Charge l'index Faiss et les chunks si les fichiers existent."""
        if os.path.exists(FAISS_INDEX_FILE) and os.path.exists(DOCUMENT_CHUNKS_FILE):
            try:
                logfire.info(f"Chargement de l'index Faiss depuis {FAISS_INDEX_FILE}...")
                self.index = faiss.read_index(FAISS_INDEX_FILE)
                logfire.info(f"Chargement des chunks depuis {DOCUMENT_CHUNKS_FILE}...")
                with open(DOCUMENT_CHUNKS_FILE, 'rb') as f:
                    self.document_chunks = pickle.load(f)
                logfire.info(f"Index ({self.index.ntotal} vecteurs) et {len(self.document_chunks)} chunks chargés.")
            except Exception as e:
                logfire.error(f"Erreur lors du chargement de l'index/chunks: {e}")
                self.index = None
                self.document_chunks = []
        else:
            logfire.warn("Fichiers d'index Faiss ou de chunks non trouvés. L'index est vide.")

    def _split_documents_to_chunks(self, documents: List[Dict[str, any]]) -> List[Dict[str, any]]:
        """Découpe les documents en chunks avec métadonnées."""
        logfire.info(f"Découpage de {len(documents)} documents/Excel Sheets en chunks (taille={CHUNK_SIZE}, chevauchement={CHUNK_OVERLAP})...")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len, # Important: mesure en caractères
            add_start_index=True, # Ajoute la position de début du chunk dans le document original
        )

        all_chunks = []
        doc_counter = 0
        for doc in documents:
            # Validation du document en entrée via ChunkIn
            try:
                validated_doc = ChunkIn(
                    page_content=doc["page_content"],
                    metadata=doc.get("metadata", {})
                )
            except Exception as e:
                logfire.warn(f"  Document ignoré (validation ChunkIn échouée): {e}")
                doc_counter += 1
                continue

            # Convertit notre format de document en format Langchain Document pour le splitter
            langchain_doc = Document(page_content=validated_doc.page_content, metadata=validated_doc.metadata)
            chunks = text_splitter.split_documents([langchain_doc])
            logfire.info(f"  Document '{validated_doc.metadata.get('filename', 'N/A')}' découpé en {len(chunks)} chunks.")

            # Enrichit chaque chunk avec des métadonnées supplémentaires
            for i, chunk in enumerate(chunks):
                try:
                    embedding_in = EmbeddingIn(
                        id=f"{doc_counter}_{i}",
                        text=chunk.page_content,
                        metadata={
                            **chunk.metadata, # Métadonnées héritées du document (source, category, etc.)
                            "chunk_id_in_doc": i, # Position du chunk dans son document d'origine
                            "start_index": chunk.metadata.get("start_index", -1) # Position de début (en caractères)
                        }
                    )
                    all_chunks.append(embedding_in.model_dump())
                except Exception as e:
                    logfire.warn(f"  Chunk {doc_counter}_{i} ignoré (validation EmbeddingIn échouée): {e}")
            doc_counter += 1

        logfire.info(f"Total de {len(all_chunks)} chunks créés.")
        return all_chunks

    def _generate_embeddings(self, chunks: List[Dict[str, any]]) -> Optional[np.ndarray]:
        """Génère les embeddings (tokenization + vectorisation) pour une liste de chunks via l'API Mistral."""
        if not MISTRAL_API_KEY:
            logfire.error("Impossible de générer les embeddings: MISTRAL_API_KEY manquante.")
            return None
        if not chunks:
            logfire.warn("Aucun chunk fourni pour générer les embeddings.")
            return None

        logfire.info(f"Génération des embeddings pour {len(chunks)} chunks (modèle: {EMBEDDING_MODEL})...")
        all_embeddings = []
        total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE

        for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
            batch_num = (i // EMBEDDING_BATCH_SIZE) + 1
            batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
            texts_to_embed = [chunk["text"] for chunk in batch_chunks]

            logfire.info(f"  Traitement du lot {batch_num}/{total_batches} ({len(texts_to_embed)} chunks)")
            try:
                response = self.mistral_client.embeddings.create(
                    model=EMBEDDING_MODEL,
                    inputs=texts_to_embed
                )
                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)
            except SDKError as e:
                logfire.error(f"Erreur API Mistral lors de la génération d'embeddings (lot {batch_num}): {e}")
                logfire.error(f"  Détails: Status Code={e.status_code}, Body={e.body}")
            except Exception as e:
                logfire.error(f"Erreur inattendue lors de la génération d'embeddings (lot {batch_num}): {e}")
                # Gérer l'erreur: ici on ajoute des vecteurs nuls pour ne pas bloquer
                num_failed = len(texts_to_embed)
                if all_embeddings: # Si on a déjà des embeddings, on prend la dimension du premier
                    dim = len(all_embeddings[0])
                else: # Sinon, on ne peut pas déterminer la dimension, on saute ce lot
                     logfire.error("Impossible de déterminer la dimension des embeddings, saut du lot.")
                     continue
                logfire.warn(f"Ajout de {num_failed} vecteurs nuls de dimension {dim} pour le lot échoué.")
                all_embeddings.extend([np.zeros(dim, dtype='float32')] * num_failed)

        if not all_embeddings:
            logfire.error("Aucun embedding n'a pu être généré.")
            return None

        embeddings_array = np.array(all_embeddings).astype('float32')
        logfire.info(f"Embeddings générés avec succès. Shape: {embeddings_array.shape}")
        return embeddings_array

    def build_index(self, documents: List[Dict[str, any]]):
        """Construit l'index Faiss à partir des documents."""
        if not documents:
            logfire.warn("Aucun document fourni pour construire l'index.")
            return

        # 1. Découper en chunks
        self.document_chunks = self._split_documents_to_chunks(documents)
        if not self.document_chunks:
            logfire.error("Le découpage n'a produit aucun chunk. Impossible de construire l'index.")
            return

        # 2. Générer les embeddings
        embeddings = self._generate_embeddings(self.document_chunks)
        if embeddings is None or embeddings.shape[0] != len(self.document_chunks):
            logfire.error("Problème de génération d'embeddings. Le nombre d'embeddings ne correspond pas au nombre de chunks.")
            # Nettoyer pour éviter un état incohérent
            self.document_chunks = []
            self.index = None
            # Supprimer les fichiers potentiellement corrompus
            if os.path.exists(FAISS_INDEX_FILE): os.remove(FAISS_INDEX_FILE)
            if os.path.exists(DOCUMENT_CHUNKS_FILE): os.remove(DOCUMENT_CHUNKS_FILE)
            return

        # 3. Valider chaque paire (chunk, vecteur) via VectordbIn avant l'ajout dans Faiss
        valid_indices = []
        for idx, (chunk, vector) in enumerate(zip(self.document_chunks, embeddings.tolist())):
            try:
                VectordbIn(chunk_id=chunk["id"], vector=vector)
                valid_indices.append(idx)
            except Exception as e:
                logfire.warn(f"  Chunk '{chunk.get('id')}' ignoré (validation VectordbIn échouée): {e}")

        if not valid_indices:
            logfire.error("Aucun vecteur valide après validation. Impossible de construire l'index.")
            return

        # Filtrer les chunks et embeddings valides
        self.document_chunks = [self.document_chunks[i] for i in valid_indices]
        embeddings = embeddings[valid_indices]

        # 4. Créer l'index Faiss optimisé pour la similarité cosinus
        dimension = embeddings.shape[1]
        logfire.info(f"Création de l'index Faiss optimisé pour la similarité cosinus avec dimension {dimension}...")

        # Normaliser les embeddings pour la similarité cosinus
        faiss.normalize_L2(embeddings)

        # Créer un index pour la similarité cosinus (IndexFlatIP = produit scalaire)
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings)
        logfire.info(f"Index Faiss créé avec {self.index.ntotal} vecteurs.")

        # 5. Sauvegarder l'index et les chunks
        self._save_index_and_chunks()

    def _save_index_and_chunks(self):
        """Sauvegarde l'index Faiss et la liste des chunks."""
        if self.index is None or not self.document_chunks:
            logfire.warn("Tentative de sauvegarde d'un index ou de chunks vides.")
            return

        os.makedirs(os.path.dirname(FAISS_INDEX_FILE), exist_ok=True)
        os.makedirs(os.path.dirname(DOCUMENT_CHUNKS_FILE), exist_ok=True)

        try:
            logfire.info(f"Sauvegarde de l'index Faiss dans {FAISS_INDEX_FILE}...")
            faiss.write_index(self.index, FAISS_INDEX_FILE)
            logfire.info(f"Sauvegarde des chunks dans {DOCUMENT_CHUNKS_FILE}...")
            with open(DOCUMENT_CHUNKS_FILE, 'wb') as f:
                pickle.dump(self.document_chunks, f)
            logfire.info("Index et chunks sauvegardés avec succès.")
        except Exception as e:
            logfire.error(f"Erreur lors de la sauvegarde de l'index/chunks: {e}")

    def generate_rag_response(self, question: str, context_results: List[Dict[str, any]],
                              system_prompt_template: str, model: str,
                              temperature: float = TEMPERATURE
                              #, top_p: float = TOP_P
                                ) -> str:
        """
        Génère une réponse LLM via l'API Mistral en injectant le contexte RAG dans le prompt.

        Args:
            question: La question de l'utilisateur.
            context_results: Liste de chunks retournés par search().
            system_prompt_template: Template avec les placeholders {context_str} et {question}.
            model: Nom du modèle Mistral à utiliser.
            temperature: Température de génération.
            top_p: Nucleus sampling (activable si besoin, désactivé par défaut via temperature).

        Returns:
            La réponse textuelle générée par le LLM.
        """
        if context_results:
            context_str = "\n\n---\n\n".join([
                f"Source: {res['metadata'].get('source', 'Inconnue')} (Score: {res['score']:.1f}%)\nContenu: {res['text']}"
                for res in context_results
            ])
        else:
            context_str = "Aucune information pertinente trouvée dans la base de connaissances pour cette question."
            logfire.warn(f"Aucun contexte trouvé pour la question: '{question}'")

        final_prompt = system_prompt_template.format(context_str=context_str, question=question)
        messages = [{"role": "user", "content": final_prompt}]

        try:
            logfire.info(f"Appel LLM Mistral (modèle='{model}') pour la question: '{question}'")
            max_retries = 6
            wait_time = 10  # Attente initiale en secondes
            for attempt in range(max_retries):
                try:
                    response = self.mistral_client.chat.complete(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        # top_p=top_p,  # Décommenter pour activer le nucleus sampling
                    )
                    break  # Succès
                except SDKError as e:
                    if getattr(e, "status_code", None) == 429 and attempt < max_retries - 1:
                        logfire.warning(
                            f"Rate limit 429 dans generate_rag_response (tentative {attempt + 1}/{max_retries}). "
                            f"Attente {wait_time}s..."
                        )
                        time.sleep(wait_time)
                        wait_time = min(wait_time * 2, 120)  # Backoff exponentiel, max 120s
                    else:
                        raise
            if response.choices:
                logfire.info("Réponse LLM reçue.")
                return response.choices[0].message.content
            else:
                logfire.warn("L'API Mistral n'a pas retourné de choix valide.")
                return "Désolé, je n'ai pas pu générer de réponse valide pour le moment."
        except SDKError as e:
            logfire.error(f"Erreur API Mistral lors de l'appel LLM: {e}")
            return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."
        except Exception as e:
            logfire.error(f"Erreur inattendue lors de l'appel LLM: {e}")
            return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."

    def search(self, query_text: str, k: int = 5, min_score: float = None) -> List[Dict[str, any]]:
        """
        Recherche les k chunks les plus pertinents pour une requête.

        Args:
            query_text: Texte de la requête
            k: Nombre de résultats à retourner
            min_score: Score minimum (entre 0 et 1) pour inclure un résultat

        Returns:
            Liste des chunks pertinents avec leurs scores
        """
        if self.index is None or not self.document_chunks:
            logfire.warn("Recherche impossible: l'index Faiss n'est pas chargé ou est vide.")
            return []
        if not MISTRAL_API_KEY:
            logfire.error("Recherche impossible: MISTRAL_API_KEY manquante pour générer l'embedding de la requête.")
            return []

        logfire.info(f"Recherche des {k} chunks les plus pertinents pour: '{query_text}'")
        try:
            # 1. Générer l'embedding de la requête
            response = self.mistral_client.embeddings.create(
                model=EMBEDDING_MODEL,
                inputs=[query_text]
            )
            query_embedding = np.array([response.data[0].embedding]).astype('float32')

            # Normaliser l'embedding de la requête pour la similarité cosinus
            faiss.normalize_L2(query_embedding)

            # 2. Rechercher dans l'index Faiss
            # Pour IndexFlatIP: scores = produit scalaire (plus grand = meilleur)
            # indices: index des chunks correspondants dans self.document_chunks
            # Demander plus de résultats si un score minimum est spécifié
            search_k = k * 3 if min_score is not None else k
            scores, indices = self.index.search(query_embedding, search_k)

            # 3. Formater les résultats
            results = []
            if indices.size > 0: # Vérifier s'il y a des résultats
                for i, idx in enumerate(indices[0]):
                    if 0 <= idx < len(self.document_chunks): # Vérifier la validité de l'index
                        chunk = self.document_chunks[idx]
                        # Convertir le score en similarité (0-1)
                        # Pour IndexFlatIP avec vecteurs normalisés, le score est déjà entre -1 et 1
                        # On le convertit en pourcentage (0-100%)
                        raw_score = float(scores[0][i])
                        similarity = raw_score * 100

                        # Filtrer les résultats en fonction du score minimum
                        # Le min_score est entre 0 et 1, mais similarity est en pourcentage (0-100)
                        min_score_percent = min_score * 100 if min_score is not None else 0
                        if min_score is not None and similarity < min_score_percent:
                            logfire.debug(f"Document filtré (score {similarity:.2f}% < minimum {min_score_percent:.2f}%)")
                            continue

                        results.append({
                            "score": similarity, # Score de similarité en pourcentage
                            "raw_score": raw_score, # Score brut pour débogage
                            "text": chunk["text"],
                            "metadata": chunk["metadata"] # Contient source, category, chunk_id_in_doc, start_index etc.
                        })
                    else:
                        logfire.warn(f"Index Faiss {idx} hors limites (taille des chunks: {len(self.document_chunks)}).")

            # Trier par score (similarité la plus élevée en premier)
            results.sort(key=lambda x: x["score"], reverse=True)

            # Limiter au nombre demandé (k) si nécessaire
            if len(results) > k:
                results = results[:k]

            if min_score is not None:
                min_score_percent = min_score * 100
                logfire.info(f"{len(results)} chunks pertinents trouvés (score minimum: {min_score_percent:.2f}%).")
            else:
                logfire.info(f"{len(results)} chunks pertinents trouvés.")

            return results

        except SDKError as e:
            logfire.error(f"Erreur API Mistral lors de la génération de l'embedding de la requête: {e}")
            logfire.error(f"  Détails: Status Code={e.status_code}, Body={e.body}")
            return []
        except Exception as e:
            logfire.error(f"Erreur inattendue lors de la recherche: {e}")
            return []