Spaces:
Running
Running
| # utils/vector_store.py | |
| import os | |
| import time | |
| import pickle | |
| import faiss | |
| import numpy as np | |
| import logfire | |
| from typing import List, Dict, Optional, Any | |
| from mistralai import Mistral, SDKError | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_core.documents import Document # Utilisé pour le format attendu par le splitter | |
| from pydantic import BaseModel, Field, field_validator | |
| from .config import ( | |
| MISTRAL_API_KEY, EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE, | |
| FAISS_INDEX_FILE, DOCUMENT_CHUNKS_FILE, CHUNK_SIZE, CHUNK_OVERLAP, | |
| TEMPERATURE, TOP_P | |
| ) | |
| # --------------------------- | |
| # Schemas Pydantic V2 | |
| # --------------------------- | |
| class ChunkIn(BaseModel): | |
| """Représente les données d'entrée pour le découpage d'un document en chunks. | |
| Correspond à la structure des documents chargés (page_content + metadata).""" | |
| page_content: str = Field(..., min_length=1, description="Contenu textuel du document") | |
| metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées du document (filename, source, category...)") | |
| def text_not_empty(cls, v): | |
| if not v or not v.strip(): | |
| raise ValueError('Le texte du document ne peut pas être vide') | |
| return v.strip() | |
| class EmbeddingIn(BaseModel): | |
| """Représente un chunk découpé prêt pour la génération d'embeddings. | |
| Correspond à la structure des dicts produits par _split_documents_to_chunks.""" | |
| id: str = Field(..., min_length=1, description="Identifiant unique du chunk (doc_index_chunk_index)") | |
| text: str = Field(..., min_length=1, description="Texte du chunk") | |
| metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées héritées + chunk_id_in_doc + start_index") | |
| def chunk_text_not_empty(cls, v): | |
| if not v or not v.strip(): | |
| raise ValueError('Le texte du chunk ne peut pas être vide') | |
| return v.strip() | |
| class VectordbIn(BaseModel): | |
| """Représente les données d'entrée pour l'ajout d'un chunk et de son embedding dans l'index Faiss.""" | |
| chunk_id: str = Field(..., min_length=1, description="Identifiant du chunk correspondant (clé 'id' du chunk)") | |
| vector: List[float] = Field(..., min_length=1, description="Vecteur d'embedding") | |
| class VectorStoreManager: | |
| """Gère la création, le chargement et la recherche dans un index Faiss.""" | |
| def __init__(self): | |
| self.index: Optional[faiss.Index] = None | |
| self.document_chunks: List[Dict[str, any]] = [] | |
| self.mistral_client = Mistral(api_key=MISTRAL_API_KEY) | |
| self._load_index_and_chunks() | |
| def _load_index_and_chunks(self): | |
| """Charge l'index Faiss et les chunks si les fichiers existent.""" | |
| if os.path.exists(FAISS_INDEX_FILE) and os.path.exists(DOCUMENT_CHUNKS_FILE): | |
| try: | |
| logfire.info(f"Chargement de l'index Faiss depuis {FAISS_INDEX_FILE}...") | |
| self.index = faiss.read_index(FAISS_INDEX_FILE) | |
| logfire.info(f"Chargement des chunks depuis {DOCUMENT_CHUNKS_FILE}...") | |
| with open(DOCUMENT_CHUNKS_FILE, 'rb') as f: | |
| self.document_chunks = pickle.load(f) | |
| logfire.info(f"Index ({self.index.ntotal} vecteurs) et {len(self.document_chunks)} chunks chargés.") | |
| except Exception as e: | |
| logfire.error(f"Erreur lors du chargement de l'index/chunks: {e}") | |
| self.index = None | |
| self.document_chunks = [] | |
| else: | |
| logfire.warn("Fichiers d'index Faiss ou de chunks non trouvés. L'index est vide.") | |
| def _split_documents_to_chunks(self, documents: List[Dict[str, any]]) -> List[Dict[str, any]]: | |
| """Découpe les documents en chunks avec métadonnées.""" | |
| logfire.info(f"Découpage de {len(documents)} documents/Excel Sheets en chunks (taille={CHUNK_SIZE}, chevauchement={CHUNK_OVERLAP})...") | |
| text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=CHUNK_SIZE, | |
| chunk_overlap=CHUNK_OVERLAP, | |
| length_function=len, # Important: mesure en caractères | |
| add_start_index=True, # Ajoute la position de début du chunk dans le document original | |
| ) | |
| all_chunks = [] | |
| doc_counter = 0 | |
| for doc in documents: | |
| # Validation du document en entrée via ChunkIn | |
| try: | |
| validated_doc = ChunkIn( | |
| page_content=doc["page_content"], | |
| metadata=doc.get("metadata", {}) | |
| ) | |
| except Exception as e: | |
| logfire.warn(f" Document ignoré (validation ChunkIn échouée): {e}") | |
| doc_counter += 1 | |
| continue | |
| # Convertit notre format de document en format Langchain Document pour le splitter | |
| langchain_doc = Document(page_content=validated_doc.page_content, metadata=validated_doc.metadata) | |
| chunks = text_splitter.split_documents([langchain_doc]) | |
| logfire.info(f" Document '{validated_doc.metadata.get('filename', 'N/A')}' découpé en {len(chunks)} chunks.") | |
| # Enrichit chaque chunk avec des métadonnées supplémentaires | |
| for i, chunk in enumerate(chunks): | |
| try: | |
| embedding_in = EmbeddingIn( | |
| id=f"{doc_counter}_{i}", | |
| text=chunk.page_content, | |
| metadata={ | |
| **chunk.metadata, # Métadonnées héritées du document (source, category, etc.) | |
| "chunk_id_in_doc": i, # Position du chunk dans son document d'origine | |
| "start_index": chunk.metadata.get("start_index", -1) # Position de début (en caractères) | |
| } | |
| ) | |
| all_chunks.append(embedding_in.model_dump()) | |
| except Exception as e: | |
| logfire.warn(f" Chunk {doc_counter}_{i} ignoré (validation EmbeddingIn échouée): {e}") | |
| doc_counter += 1 | |
| logfire.info(f"Total de {len(all_chunks)} chunks créés.") | |
| return all_chunks | |
| def _generate_embeddings(self, chunks: List[Dict[str, any]]) -> Optional[np.ndarray]: | |
| """Génère les embeddings (tokenization + vectorisation) pour une liste de chunks via l'API Mistral.""" | |
| if not MISTRAL_API_KEY: | |
| logfire.error("Impossible de générer les embeddings: MISTRAL_API_KEY manquante.") | |
| return None | |
| if not chunks: | |
| logfire.warn("Aucun chunk fourni pour générer les embeddings.") | |
| return None | |
| logfire.info(f"Génération des embeddings pour {len(chunks)} chunks (modèle: {EMBEDDING_MODEL})...") | |
| all_embeddings = [] | |
| total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE | |
| for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE): | |
| batch_num = (i // EMBEDDING_BATCH_SIZE) + 1 | |
| batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE] | |
| texts_to_embed = [chunk["text"] for chunk in batch_chunks] | |
| logfire.info(f" Traitement du lot {batch_num}/{total_batches} ({len(texts_to_embed)} chunks)") | |
| try: | |
| response = self.mistral_client.embeddings.create( | |
| model=EMBEDDING_MODEL, | |
| inputs=texts_to_embed | |
| ) | |
| batch_embeddings = [data.embedding for data in response.data] | |
| all_embeddings.extend(batch_embeddings) | |
| except SDKError as e: | |
| logfire.error(f"Erreur API Mistral lors de la génération d'embeddings (lot {batch_num}): {e}") | |
| logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}") | |
| except Exception as e: | |
| logfire.error(f"Erreur inattendue lors de la génération d'embeddings (lot {batch_num}): {e}") | |
| # Gérer l'erreur: ici on ajoute des vecteurs nuls pour ne pas bloquer | |
| num_failed = len(texts_to_embed) | |
| if all_embeddings: # Si on a déjà des embeddings, on prend la dimension du premier | |
| dim = len(all_embeddings[0]) | |
| else: # Sinon, on ne peut pas déterminer la dimension, on saute ce lot | |
| logfire.error("Impossible de déterminer la dimension des embeddings, saut du lot.") | |
| continue | |
| logfire.warn(f"Ajout de {num_failed} vecteurs nuls de dimension {dim} pour le lot échoué.") | |
| all_embeddings.extend([np.zeros(dim, dtype='float32')] * num_failed) | |
| if not all_embeddings: | |
| logfire.error("Aucun embedding n'a pu être généré.") | |
| return None | |
| embeddings_array = np.array(all_embeddings).astype('float32') | |
| logfire.info(f"Embeddings générés avec succès. Shape: {embeddings_array.shape}") | |
| return embeddings_array | |
| def build_index(self, documents: List[Dict[str, any]]): | |
| """Construit l'index Faiss à partir des documents.""" | |
| if not documents: | |
| logfire.warn("Aucun document fourni pour construire l'index.") | |
| return | |
| # 1. Découper en chunks | |
| self.document_chunks = self._split_documents_to_chunks(documents) | |
| if not self.document_chunks: | |
| logfire.error("Le découpage n'a produit aucun chunk. Impossible de construire l'index.") | |
| return | |
| # 2. Générer les embeddings | |
| embeddings = self._generate_embeddings(self.document_chunks) | |
| if embeddings is None or embeddings.shape[0] != len(self.document_chunks): | |
| logfire.error("Problème de génération d'embeddings. Le nombre d'embeddings ne correspond pas au nombre de chunks.") | |
| # Nettoyer pour éviter un état incohérent | |
| self.document_chunks = [] | |
| self.index = None | |
| # Supprimer les fichiers potentiellement corrompus | |
| if os.path.exists(FAISS_INDEX_FILE): os.remove(FAISS_INDEX_FILE) | |
| if os.path.exists(DOCUMENT_CHUNKS_FILE): os.remove(DOCUMENT_CHUNKS_FILE) | |
| return | |
| # 3. Valider chaque paire (chunk, vecteur) via VectordbIn avant l'ajout dans Faiss | |
| valid_indices = [] | |
| for idx, (chunk, vector) in enumerate(zip(self.document_chunks, embeddings.tolist())): | |
| try: | |
| VectordbIn(chunk_id=chunk["id"], vector=vector) | |
| valid_indices.append(idx) | |
| except Exception as e: | |
| logfire.warn(f" Chunk '{chunk.get('id')}' ignoré (validation VectordbIn échouée): {e}") | |
| if not valid_indices: | |
| logfire.error("Aucun vecteur valide après validation. Impossible de construire l'index.") | |
| return | |
| # Filtrer les chunks et embeddings valides | |
| self.document_chunks = [self.document_chunks[i] for i in valid_indices] | |
| embeddings = embeddings[valid_indices] | |
| # 4. Créer l'index Faiss optimisé pour la similarité cosinus | |
| dimension = embeddings.shape[1] | |
| logfire.info(f"Création de l'index Faiss optimisé pour la similarité cosinus avec dimension {dimension}...") | |
| # Normaliser les embeddings pour la similarité cosinus | |
| faiss.normalize_L2(embeddings) | |
| # Créer un index pour la similarité cosinus (IndexFlatIP = produit scalaire) | |
| self.index = faiss.IndexFlatIP(dimension) | |
| self.index.add(embeddings) | |
| logfire.info(f"Index Faiss créé avec {self.index.ntotal} vecteurs.") | |
| # 5. Sauvegarder l'index et les chunks | |
| self._save_index_and_chunks() | |
| def _save_index_and_chunks(self): | |
| """Sauvegarde l'index Faiss et la liste des chunks.""" | |
| if self.index is None or not self.document_chunks: | |
| logfire.warn("Tentative de sauvegarde d'un index ou de chunks vides.") | |
| return | |
| os.makedirs(os.path.dirname(FAISS_INDEX_FILE), exist_ok=True) | |
| os.makedirs(os.path.dirname(DOCUMENT_CHUNKS_FILE), exist_ok=True) | |
| try: | |
| logfire.info(f"Sauvegarde de l'index Faiss dans {FAISS_INDEX_FILE}...") | |
| faiss.write_index(self.index, FAISS_INDEX_FILE) | |
| logfire.info(f"Sauvegarde des chunks dans {DOCUMENT_CHUNKS_FILE}...") | |
| with open(DOCUMENT_CHUNKS_FILE, 'wb') as f: | |
| pickle.dump(self.document_chunks, f) | |
| logfire.info("Index et chunks sauvegardés avec succès.") | |
| except Exception as e: | |
| logfire.error(f"Erreur lors de la sauvegarde de l'index/chunks: {e}") | |
| def generate_rag_response(self, question: str, context_results: List[Dict[str, any]], | |
| system_prompt_template: str, model: str, | |
| temperature: float = TEMPERATURE | |
| #, top_p: float = TOP_P | |
| ) -> str: | |
| """ | |
| Génère une réponse LLM via l'API Mistral en injectant le contexte RAG dans le prompt. | |
| Args: | |
| question: La question de l'utilisateur. | |
| context_results: Liste de chunks retournés par search(). | |
| system_prompt_template: Template avec les placeholders {context_str} et {question}. | |
| model: Nom du modèle Mistral à utiliser. | |
| temperature: Température de génération. | |
| top_p: Nucleus sampling (activable si besoin, désactivé par défaut via temperature). | |
| Returns: | |
| La réponse textuelle générée par le LLM. | |
| """ | |
| if context_results: | |
| context_str = "\n\n---\n\n".join([ | |
| f"Source: {res['metadata'].get('source', 'Inconnue')} (Score: {res['score']:.1f}%)\nContenu: {res['text']}" | |
| for res in context_results | |
| ]) | |
| else: | |
| context_str = "Aucune information pertinente trouvée dans la base de connaissances pour cette question." | |
| logfire.warn(f"Aucun contexte trouvé pour la question: '{question}'") | |
| final_prompt = system_prompt_template.format(context_str=context_str, question=question) | |
| messages = [{"role": "user", "content": final_prompt}] | |
| try: | |
| logfire.info(f"Appel LLM Mistral (modèle='{model}') pour la question: '{question}'") | |
| max_retries = 6 | |
| wait_time = 10 # Attente initiale en secondes | |
| for attempt in range(max_retries): | |
| try: | |
| response = self.mistral_client.chat.complete( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| # top_p=top_p, # Décommenter pour activer le nucleus sampling | |
| ) | |
| break # Succès | |
| except SDKError as e: | |
| if getattr(e, "status_code", None) == 429 and attempt < max_retries - 1: | |
| logfire.warning( | |
| f"Rate limit 429 dans generate_rag_response (tentative {attempt + 1}/{max_retries}). " | |
| f"Attente {wait_time}s..." | |
| ) | |
| time.sleep(wait_time) | |
| wait_time = min(wait_time * 2, 120) # Backoff exponentiel, max 120s | |
| else: | |
| raise | |
| if response.choices: | |
| logfire.info("Réponse LLM reçue.") | |
| return response.choices[0].message.content | |
| else: | |
| logfire.warn("L'API Mistral n'a pas retourné de choix valide.") | |
| return "Désolé, je n'ai pas pu générer de réponse valide pour le moment." | |
| except SDKError as e: | |
| logfire.error(f"Erreur API Mistral lors de l'appel LLM: {e}") | |
| return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard." | |
| except Exception as e: | |
| logfire.error(f"Erreur inattendue lors de l'appel LLM: {e}") | |
| return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard." | |
| def search(self, query_text: str, k: int = 5, min_score: float = None) -> List[Dict[str, any]]: | |
| """ | |
| Recherche les k chunks les plus pertinents pour une requête. | |
| Args: | |
| query_text: Texte de la requête | |
| k: Nombre de résultats à retourner | |
| min_score: Score minimum (entre 0 et 1) pour inclure un résultat | |
| Returns: | |
| Liste des chunks pertinents avec leurs scores | |
| """ | |
| if self.index is None or not self.document_chunks: | |
| logfire.warn("Recherche impossible: l'index Faiss n'est pas chargé ou est vide.") | |
| return [] | |
| if not MISTRAL_API_KEY: | |
| logfire.error("Recherche impossible: MISTRAL_API_KEY manquante pour générer l'embedding de la requête.") | |
| return [] | |
| logfire.info(f"Recherche des {k} chunks les plus pertinents pour: '{query_text}'") | |
| try: | |
| # 1. Générer l'embedding de la requête | |
| response = self.mistral_client.embeddings.create( | |
| model=EMBEDDING_MODEL, | |
| inputs=[query_text] | |
| ) | |
| query_embedding = np.array([response.data[0].embedding]).astype('float32') | |
| # Normaliser l'embedding de la requête pour la similarité cosinus | |
| faiss.normalize_L2(query_embedding) | |
| # 2. Rechercher dans l'index Faiss | |
| # Pour IndexFlatIP: scores = produit scalaire (plus grand = meilleur) | |
| # indices: index des chunks correspondants dans self.document_chunks | |
| # Demander plus de résultats si un score minimum est spécifié | |
| search_k = k * 3 if min_score is not None else k | |
| scores, indices = self.index.search(query_embedding, search_k) | |
| # 3. Formater les résultats | |
| results = [] | |
| if indices.size > 0: # Vérifier s'il y a des résultats | |
| for i, idx in enumerate(indices[0]): | |
| if 0 <= idx < len(self.document_chunks): # Vérifier la validité de l'index | |
| chunk = self.document_chunks[idx] | |
| # Convertir le score en similarité (0-1) | |
| # Pour IndexFlatIP avec vecteurs normalisés, le score est déjà entre -1 et 1 | |
| # On le convertit en pourcentage (0-100%) | |
| raw_score = float(scores[0][i]) | |
| similarity = raw_score * 100 | |
| # Filtrer les résultats en fonction du score minimum | |
| # Le min_score est entre 0 et 1, mais similarity est en pourcentage (0-100) | |
| min_score_percent = min_score * 100 if min_score is not None else 0 | |
| if min_score is not None and similarity < min_score_percent: | |
| logfire.debug(f"Document filtré (score {similarity:.2f}% < minimum {min_score_percent:.2f}%)") | |
| continue | |
| results.append({ | |
| "score": similarity, # Score de similarité en pourcentage | |
| "raw_score": raw_score, # Score brut pour débogage | |
| "text": chunk["text"], | |
| "metadata": chunk["metadata"] # Contient source, category, chunk_id_in_doc, start_index etc. | |
| }) | |
| else: | |
| logfire.warn(f"Index Faiss {idx} hors limites (taille des chunks: {len(self.document_chunks)}).") | |
| # Trier par score (similarité la plus élevée en premier) | |
| results.sort(key=lambda x: x["score"], reverse=True) | |
| # Limiter au nombre demandé (k) si nécessaire | |
| if len(results) > k: | |
| results = results[:k] | |
| if min_score is not None: | |
| min_score_percent = min_score * 100 | |
| logfire.info(f"{len(results)} chunks pertinents trouvés (score minimum: {min_score_percent:.2f}%).") | |
| else: | |
| logfire.info(f"{len(results)} chunks pertinents trouvés.") | |
| return results | |
| except SDKError as e: | |
| logfire.error(f"Erreur API Mistral lors de la génération de l'embedding de la requête: {e}") | |
| logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}") | |
| return [] | |
| except Exception as e: | |
| logfire.error(f"Erreur inattendue lors de la recherche: {e}") | |
| return [] |