Spaces:
Sleeping
Sleeping
File size: 20,770 Bytes
7cb1544 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 | # utils/vector_store.py
import os
import time
import pickle
import faiss
import numpy as np
import logfire
from typing import List, Dict, Optional, Any
from mistralai import Mistral, SDKError
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document # Utilisé pour le format attendu par le splitter
from pydantic import BaseModel, Field, field_validator
from .config import (
MISTRAL_API_KEY, EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE,
FAISS_INDEX_FILE, DOCUMENT_CHUNKS_FILE, CHUNK_SIZE, CHUNK_OVERLAP,
TEMPERATURE, TOP_P
)
# ---------------------------
# Schemas Pydantic V2
# ---------------------------
class ChunkIn(BaseModel):
"""Représente les données d'entrée pour le découpage d'un document en chunks.
Correspond à la structure des documents chargés (page_content + metadata)."""
page_content: str = Field(..., min_length=1, description="Contenu textuel du document")
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées du document (filename, source, category...)")
@field_validator('page_content')
@classmethod
def text_not_empty(cls, v):
if not v or not v.strip():
raise ValueError('Le texte du document ne peut pas être vide')
return v.strip()
class EmbeddingIn(BaseModel):
"""Représente un chunk découpé prêt pour la génération d'embeddings.
Correspond à la structure des dicts produits par _split_documents_to_chunks."""
id: str = Field(..., min_length=1, description="Identifiant unique du chunk (doc_index_chunk_index)")
text: str = Field(..., min_length=1, description="Texte du chunk")
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées héritées + chunk_id_in_doc + start_index")
@field_validator('text')
@classmethod
def chunk_text_not_empty(cls, v):
if not v or not v.strip():
raise ValueError('Le texte du chunk ne peut pas être vide')
return v.strip()
class VectordbIn(BaseModel):
"""Représente les données d'entrée pour l'ajout d'un chunk et de son embedding dans l'index Faiss."""
chunk_id: str = Field(..., min_length=1, description="Identifiant du chunk correspondant (clé 'id' du chunk)")
vector: List[float] = Field(..., min_length=1, description="Vecteur d'embedding")
class VectorStoreManager:
"""Gère la création, le chargement et la recherche dans un index Faiss."""
def __init__(self):
self.index: Optional[faiss.Index] = None
self.document_chunks: List[Dict[str, any]] = []
self.mistral_client = Mistral(api_key=MISTRAL_API_KEY)
self._load_index_and_chunks()
def _load_index_and_chunks(self):
"""Charge l'index Faiss et les chunks si les fichiers existent."""
if os.path.exists(FAISS_INDEX_FILE) and os.path.exists(DOCUMENT_CHUNKS_FILE):
try:
logfire.info(f"Chargement de l'index Faiss depuis {FAISS_INDEX_FILE}...")
self.index = faiss.read_index(FAISS_INDEX_FILE)
logfire.info(f"Chargement des chunks depuis {DOCUMENT_CHUNKS_FILE}...")
with open(DOCUMENT_CHUNKS_FILE, 'rb') as f:
self.document_chunks = pickle.load(f)
logfire.info(f"Index ({self.index.ntotal} vecteurs) et {len(self.document_chunks)} chunks chargés.")
except Exception as e:
logfire.error(f"Erreur lors du chargement de l'index/chunks: {e}")
self.index = None
self.document_chunks = []
else:
logfire.warn("Fichiers d'index Faiss ou de chunks non trouvés. L'index est vide.")
def _split_documents_to_chunks(self, documents: List[Dict[str, any]]) -> List[Dict[str, any]]:
"""Découpe les documents en chunks avec métadonnées."""
logfire.info(f"Découpage de {len(documents)} documents/Excel Sheets en chunks (taille={CHUNK_SIZE}, chevauchement={CHUNK_OVERLAP})...")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len, # Important: mesure en caractères
add_start_index=True, # Ajoute la position de début du chunk dans le document original
)
all_chunks = []
doc_counter = 0
for doc in documents:
# Validation du document en entrée via ChunkIn
try:
validated_doc = ChunkIn(
page_content=doc["page_content"],
metadata=doc.get("metadata", {})
)
except Exception as e:
logfire.warn(f" Document ignoré (validation ChunkIn échouée): {e}")
doc_counter += 1
continue
# Convertit notre format de document en format Langchain Document pour le splitter
langchain_doc = Document(page_content=validated_doc.page_content, metadata=validated_doc.metadata)
chunks = text_splitter.split_documents([langchain_doc])
logfire.info(f" Document '{validated_doc.metadata.get('filename', 'N/A')}' découpé en {len(chunks)} chunks.")
# Enrichit chaque chunk avec des métadonnées supplémentaires
for i, chunk in enumerate(chunks):
try:
embedding_in = EmbeddingIn(
id=f"{doc_counter}_{i}",
text=chunk.page_content,
metadata={
**chunk.metadata, # Métadonnées héritées du document (source, category, etc.)
"chunk_id_in_doc": i, # Position du chunk dans son document d'origine
"start_index": chunk.metadata.get("start_index", -1) # Position de début (en caractères)
}
)
all_chunks.append(embedding_in.model_dump())
except Exception as e:
logfire.warn(f" Chunk {doc_counter}_{i} ignoré (validation EmbeddingIn échouée): {e}")
doc_counter += 1
logfire.info(f"Total de {len(all_chunks)} chunks créés.")
return all_chunks
def _generate_embeddings(self, chunks: List[Dict[str, any]]) -> Optional[np.ndarray]:
"""Génère les embeddings (tokenization + vectorisation) pour une liste de chunks via l'API Mistral."""
if not MISTRAL_API_KEY:
logfire.error("Impossible de générer les embeddings: MISTRAL_API_KEY manquante.")
return None
if not chunks:
logfire.warn("Aucun chunk fourni pour générer les embeddings.")
return None
logfire.info(f"Génération des embeddings pour {len(chunks)} chunks (modèle: {EMBEDDING_MODEL})...")
all_embeddings = []
total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
batch_num = (i // EMBEDDING_BATCH_SIZE) + 1
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
texts_to_embed = [chunk["text"] for chunk in batch_chunks]
logfire.info(f" Traitement du lot {batch_num}/{total_batches} ({len(texts_to_embed)} chunks)")
try:
response = self.mistral_client.embeddings.create(
model=EMBEDDING_MODEL,
inputs=texts_to_embed
)
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de la génération d'embeddings (lot {batch_num}): {e}")
logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}")
except Exception as e:
logfire.error(f"Erreur inattendue lors de la génération d'embeddings (lot {batch_num}): {e}")
# Gérer l'erreur: ici on ajoute des vecteurs nuls pour ne pas bloquer
num_failed = len(texts_to_embed)
if all_embeddings: # Si on a déjà des embeddings, on prend la dimension du premier
dim = len(all_embeddings[0])
else: # Sinon, on ne peut pas déterminer la dimension, on saute ce lot
logfire.error("Impossible de déterminer la dimension des embeddings, saut du lot.")
continue
logfire.warn(f"Ajout de {num_failed} vecteurs nuls de dimension {dim} pour le lot échoué.")
all_embeddings.extend([np.zeros(dim, dtype='float32')] * num_failed)
if not all_embeddings:
logfire.error("Aucun embedding n'a pu être généré.")
return None
embeddings_array = np.array(all_embeddings).astype('float32')
logfire.info(f"Embeddings générés avec succès. Shape: {embeddings_array.shape}")
return embeddings_array
def build_index(self, documents: List[Dict[str, any]]):
"""Construit l'index Faiss à partir des documents."""
if not documents:
logfire.warn("Aucun document fourni pour construire l'index.")
return
# 1. Découper en chunks
self.document_chunks = self._split_documents_to_chunks(documents)
if not self.document_chunks:
logfire.error("Le découpage n'a produit aucun chunk. Impossible de construire l'index.")
return
# 2. Générer les embeddings
embeddings = self._generate_embeddings(self.document_chunks)
if embeddings is None or embeddings.shape[0] != len(self.document_chunks):
logfire.error("Problème de génération d'embeddings. Le nombre d'embeddings ne correspond pas au nombre de chunks.")
# Nettoyer pour éviter un état incohérent
self.document_chunks = []
self.index = None
# Supprimer les fichiers potentiellement corrompus
if os.path.exists(FAISS_INDEX_FILE): os.remove(FAISS_INDEX_FILE)
if os.path.exists(DOCUMENT_CHUNKS_FILE): os.remove(DOCUMENT_CHUNKS_FILE)
return
# 3. Valider chaque paire (chunk, vecteur) via VectordbIn avant l'ajout dans Faiss
valid_indices = []
for idx, (chunk, vector) in enumerate(zip(self.document_chunks, embeddings.tolist())):
try:
VectordbIn(chunk_id=chunk["id"], vector=vector)
valid_indices.append(idx)
except Exception as e:
logfire.warn(f" Chunk '{chunk.get('id')}' ignoré (validation VectordbIn échouée): {e}")
if not valid_indices:
logfire.error("Aucun vecteur valide après validation. Impossible de construire l'index.")
return
# Filtrer les chunks et embeddings valides
self.document_chunks = [self.document_chunks[i] for i in valid_indices]
embeddings = embeddings[valid_indices]
# 4. Créer l'index Faiss optimisé pour la similarité cosinus
dimension = embeddings.shape[1]
logfire.info(f"Création de l'index Faiss optimisé pour la similarité cosinus avec dimension {dimension}...")
# Normaliser les embeddings pour la similarité cosinus
faiss.normalize_L2(embeddings)
# Créer un index pour la similarité cosinus (IndexFlatIP = produit scalaire)
self.index = faiss.IndexFlatIP(dimension)
self.index.add(embeddings)
logfire.info(f"Index Faiss créé avec {self.index.ntotal} vecteurs.")
# 5. Sauvegarder l'index et les chunks
self._save_index_and_chunks()
def _save_index_and_chunks(self):
"""Sauvegarde l'index Faiss et la liste des chunks."""
if self.index is None or not self.document_chunks:
logfire.warn("Tentative de sauvegarde d'un index ou de chunks vides.")
return
os.makedirs(os.path.dirname(FAISS_INDEX_FILE), exist_ok=True)
os.makedirs(os.path.dirname(DOCUMENT_CHUNKS_FILE), exist_ok=True)
try:
logfire.info(f"Sauvegarde de l'index Faiss dans {FAISS_INDEX_FILE}...")
faiss.write_index(self.index, FAISS_INDEX_FILE)
logfire.info(f"Sauvegarde des chunks dans {DOCUMENT_CHUNKS_FILE}...")
with open(DOCUMENT_CHUNKS_FILE, 'wb') as f:
pickle.dump(self.document_chunks, f)
logfire.info("Index et chunks sauvegardés avec succès.")
except Exception as e:
logfire.error(f"Erreur lors de la sauvegarde de l'index/chunks: {e}")
def generate_rag_response(self, question: str, context_results: List[Dict[str, any]],
system_prompt_template: str, model: str,
temperature: float = TEMPERATURE
#, top_p: float = TOP_P
) -> str:
"""
Génère une réponse LLM via l'API Mistral en injectant le contexte RAG dans le prompt.
Args:
question: La question de l'utilisateur.
context_results: Liste de chunks retournés par search().
system_prompt_template: Template avec les placeholders {context_str} et {question}.
model: Nom du modèle Mistral à utiliser.
temperature: Température de génération.
top_p: Nucleus sampling (activable si besoin, désactivé par défaut via temperature).
Returns:
La réponse textuelle générée par le LLM.
"""
if context_results:
context_str = "\n\n---\n\n".join([
f"Source: {res['metadata'].get('source', 'Inconnue')} (Score: {res['score']:.1f}%)\nContenu: {res['text']}"
for res in context_results
])
else:
context_str = "Aucune information pertinente trouvée dans la base de connaissances pour cette question."
logfire.warn(f"Aucun contexte trouvé pour la question: '{question}'")
final_prompt = system_prompt_template.format(context_str=context_str, question=question)
messages = [{"role": "user", "content": final_prompt}]
try:
logfire.info(f"Appel LLM Mistral (modèle='{model}') pour la question: '{question}'")
max_retries = 6
wait_time = 10 # Attente initiale en secondes
for attempt in range(max_retries):
try:
response = self.mistral_client.chat.complete(
model=model,
messages=messages,
temperature=temperature,
# top_p=top_p, # Décommenter pour activer le nucleus sampling
)
break # Succès
except SDKError as e:
if getattr(e, "status_code", None) == 429 and attempt < max_retries - 1:
logfire.warning(
f"Rate limit 429 dans generate_rag_response (tentative {attempt + 1}/{max_retries}). "
f"Attente {wait_time}s..."
)
time.sleep(wait_time)
wait_time = min(wait_time * 2, 120) # Backoff exponentiel, max 120s
else:
raise
if response.choices:
logfire.info("Réponse LLM reçue.")
return response.choices[0].message.content
else:
logfire.warn("L'API Mistral n'a pas retourné de choix valide.")
return "Désolé, je n'ai pas pu générer de réponse valide pour le moment."
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de l'appel LLM: {e}")
return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."
except Exception as e:
logfire.error(f"Erreur inattendue lors de l'appel LLM: {e}")
return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."
def search(self, query_text: str, k: int = 5, min_score: float = None) -> List[Dict[str, any]]:
"""
Recherche les k chunks les plus pertinents pour une requête.
Args:
query_text: Texte de la requête
k: Nombre de résultats à retourner
min_score: Score minimum (entre 0 et 1) pour inclure un résultat
Returns:
Liste des chunks pertinents avec leurs scores
"""
if self.index is None or not self.document_chunks:
logfire.warn("Recherche impossible: l'index Faiss n'est pas chargé ou est vide.")
return []
if not MISTRAL_API_KEY:
logfire.error("Recherche impossible: MISTRAL_API_KEY manquante pour générer l'embedding de la requête.")
return []
logfire.info(f"Recherche des {k} chunks les plus pertinents pour: '{query_text}'")
try:
# 1. Générer l'embedding de la requête
response = self.mistral_client.embeddings.create(
model=EMBEDDING_MODEL,
inputs=[query_text]
)
query_embedding = np.array([response.data[0].embedding]).astype('float32')
# Normaliser l'embedding de la requête pour la similarité cosinus
faiss.normalize_L2(query_embedding)
# 2. Rechercher dans l'index Faiss
# Pour IndexFlatIP: scores = produit scalaire (plus grand = meilleur)
# indices: index des chunks correspondants dans self.document_chunks
# Demander plus de résultats si un score minimum est spécifié
search_k = k * 3 if min_score is not None else k
scores, indices = self.index.search(query_embedding, search_k)
# 3. Formater les résultats
results = []
if indices.size > 0: # Vérifier s'il y a des résultats
for i, idx in enumerate(indices[0]):
if 0 <= idx < len(self.document_chunks): # Vérifier la validité de l'index
chunk = self.document_chunks[idx]
# Convertir le score en similarité (0-1)
# Pour IndexFlatIP avec vecteurs normalisés, le score est déjà entre -1 et 1
# On le convertit en pourcentage (0-100%)
raw_score = float(scores[0][i])
similarity = raw_score * 100
# Filtrer les résultats en fonction du score minimum
# Le min_score est entre 0 et 1, mais similarity est en pourcentage (0-100)
min_score_percent = min_score * 100 if min_score is not None else 0
if min_score is not None and similarity < min_score_percent:
logfire.debug(f"Document filtré (score {similarity:.2f}% < minimum {min_score_percent:.2f}%)")
continue
results.append({
"score": similarity, # Score de similarité en pourcentage
"raw_score": raw_score, # Score brut pour débogage
"text": chunk["text"],
"metadata": chunk["metadata"] # Contient source, category, chunk_id_in_doc, start_index etc.
})
else:
logfire.warn(f"Index Faiss {idx} hors limites (taille des chunks: {len(self.document_chunks)}).")
# Trier par score (similarité la plus élevée en premier)
results.sort(key=lambda x: x["score"], reverse=True)
# Limiter au nombre demandé (k) si nécessaire
if len(results) > k:
results = results[:k]
if min_score is not None:
min_score_percent = min_score * 100
logfire.info(f"{len(results)} chunks pertinents trouvés (score minimum: {min_score_percent:.2f}%).")
else:
logfire.info(f"{len(results)} chunks pertinents trouvés.")
return results
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de la génération de l'embedding de la requête: {e}")
logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}")
return []
except Exception as e:
logfire.error(f"Erreur inattendue lors de la recherche: {e}")
return [] |