oc_mlops_projet_3 / utils /vector_store.py
CedM's picture
Déploiement automatique depuis GitLab CI
7cb1544 verified
# utils/vector_store.py
import os
import time
import pickle
import faiss
import numpy as np
import logfire
from typing import List, Dict, Optional, Any
from mistralai import Mistral, SDKError
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document # Utilisé pour le format attendu par le splitter
from pydantic import BaseModel, Field, field_validator
from .config import (
MISTRAL_API_KEY, EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE,
FAISS_INDEX_FILE, DOCUMENT_CHUNKS_FILE, CHUNK_SIZE, CHUNK_OVERLAP,
TEMPERATURE, TOP_P
)
# ---------------------------
# Schemas Pydantic V2
# ---------------------------
class ChunkIn(BaseModel):
"""Représente les données d'entrée pour le découpage d'un document en chunks.
Correspond à la structure des documents chargés (page_content + metadata)."""
page_content: str = Field(..., min_length=1, description="Contenu textuel du document")
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées du document (filename, source, category...)")
@field_validator('page_content')
@classmethod
def text_not_empty(cls, v):
if not v or not v.strip():
raise ValueError('Le texte du document ne peut pas être vide')
return v.strip()
class EmbeddingIn(BaseModel):
"""Représente un chunk découpé prêt pour la génération d'embeddings.
Correspond à la structure des dicts produits par _split_documents_to_chunks."""
id: str = Field(..., min_length=1, description="Identifiant unique du chunk (doc_index_chunk_index)")
text: str = Field(..., min_length=1, description="Texte du chunk")
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Métadonnées héritées + chunk_id_in_doc + start_index")
@field_validator('text')
@classmethod
def chunk_text_not_empty(cls, v):
if not v or not v.strip():
raise ValueError('Le texte du chunk ne peut pas être vide')
return v.strip()
class VectordbIn(BaseModel):
"""Représente les données d'entrée pour l'ajout d'un chunk et de son embedding dans l'index Faiss."""
chunk_id: str = Field(..., min_length=1, description="Identifiant du chunk correspondant (clé 'id' du chunk)")
vector: List[float] = Field(..., min_length=1, description="Vecteur d'embedding")
class VectorStoreManager:
"""Gère la création, le chargement et la recherche dans un index Faiss."""
def __init__(self):
self.index: Optional[faiss.Index] = None
self.document_chunks: List[Dict[str, any]] = []
self.mistral_client = Mistral(api_key=MISTRAL_API_KEY)
self._load_index_and_chunks()
def _load_index_and_chunks(self):
"""Charge l'index Faiss et les chunks si les fichiers existent."""
if os.path.exists(FAISS_INDEX_FILE) and os.path.exists(DOCUMENT_CHUNKS_FILE):
try:
logfire.info(f"Chargement de l'index Faiss depuis {FAISS_INDEX_FILE}...")
self.index = faiss.read_index(FAISS_INDEX_FILE)
logfire.info(f"Chargement des chunks depuis {DOCUMENT_CHUNKS_FILE}...")
with open(DOCUMENT_CHUNKS_FILE, 'rb') as f:
self.document_chunks = pickle.load(f)
logfire.info(f"Index ({self.index.ntotal} vecteurs) et {len(self.document_chunks)} chunks chargés.")
except Exception as e:
logfire.error(f"Erreur lors du chargement de l'index/chunks: {e}")
self.index = None
self.document_chunks = []
else:
logfire.warn("Fichiers d'index Faiss ou de chunks non trouvés. L'index est vide.")
def _split_documents_to_chunks(self, documents: List[Dict[str, any]]) -> List[Dict[str, any]]:
"""Découpe les documents en chunks avec métadonnées."""
logfire.info(f"Découpage de {len(documents)} documents/Excel Sheets en chunks (taille={CHUNK_SIZE}, chevauchement={CHUNK_OVERLAP})...")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len, # Important: mesure en caractères
add_start_index=True, # Ajoute la position de début du chunk dans le document original
)
all_chunks = []
doc_counter = 0
for doc in documents:
# Validation du document en entrée via ChunkIn
try:
validated_doc = ChunkIn(
page_content=doc["page_content"],
metadata=doc.get("metadata", {})
)
except Exception as e:
logfire.warn(f" Document ignoré (validation ChunkIn échouée): {e}")
doc_counter += 1
continue
# Convertit notre format de document en format Langchain Document pour le splitter
langchain_doc = Document(page_content=validated_doc.page_content, metadata=validated_doc.metadata)
chunks = text_splitter.split_documents([langchain_doc])
logfire.info(f" Document '{validated_doc.metadata.get('filename', 'N/A')}' découpé en {len(chunks)} chunks.")
# Enrichit chaque chunk avec des métadonnées supplémentaires
for i, chunk in enumerate(chunks):
try:
embedding_in = EmbeddingIn(
id=f"{doc_counter}_{i}",
text=chunk.page_content,
metadata={
**chunk.metadata, # Métadonnées héritées du document (source, category, etc.)
"chunk_id_in_doc": i, # Position du chunk dans son document d'origine
"start_index": chunk.metadata.get("start_index", -1) # Position de début (en caractères)
}
)
all_chunks.append(embedding_in.model_dump())
except Exception as e:
logfire.warn(f" Chunk {doc_counter}_{i} ignoré (validation EmbeddingIn échouée): {e}")
doc_counter += 1
logfire.info(f"Total de {len(all_chunks)} chunks créés.")
return all_chunks
def _generate_embeddings(self, chunks: List[Dict[str, any]]) -> Optional[np.ndarray]:
"""Génère les embeddings (tokenization + vectorisation) pour une liste de chunks via l'API Mistral."""
if not MISTRAL_API_KEY:
logfire.error("Impossible de générer les embeddings: MISTRAL_API_KEY manquante.")
return None
if not chunks:
logfire.warn("Aucun chunk fourni pour générer les embeddings.")
return None
logfire.info(f"Génération des embeddings pour {len(chunks)} chunks (modèle: {EMBEDDING_MODEL})...")
all_embeddings = []
total_batches = (len(chunks) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
batch_num = (i // EMBEDDING_BATCH_SIZE) + 1
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
texts_to_embed = [chunk["text"] for chunk in batch_chunks]
logfire.info(f" Traitement du lot {batch_num}/{total_batches} ({len(texts_to_embed)} chunks)")
try:
response = self.mistral_client.embeddings.create(
model=EMBEDDING_MODEL,
inputs=texts_to_embed
)
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de la génération d'embeddings (lot {batch_num}): {e}")
logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}")
except Exception as e:
logfire.error(f"Erreur inattendue lors de la génération d'embeddings (lot {batch_num}): {e}")
# Gérer l'erreur: ici on ajoute des vecteurs nuls pour ne pas bloquer
num_failed = len(texts_to_embed)
if all_embeddings: # Si on a déjà des embeddings, on prend la dimension du premier
dim = len(all_embeddings[0])
else: # Sinon, on ne peut pas déterminer la dimension, on saute ce lot
logfire.error("Impossible de déterminer la dimension des embeddings, saut du lot.")
continue
logfire.warn(f"Ajout de {num_failed} vecteurs nuls de dimension {dim} pour le lot échoué.")
all_embeddings.extend([np.zeros(dim, dtype='float32')] * num_failed)
if not all_embeddings:
logfire.error("Aucun embedding n'a pu être généré.")
return None
embeddings_array = np.array(all_embeddings).astype('float32')
logfire.info(f"Embeddings générés avec succès. Shape: {embeddings_array.shape}")
return embeddings_array
def build_index(self, documents: List[Dict[str, any]]):
"""Construit l'index Faiss à partir des documents."""
if not documents:
logfire.warn("Aucun document fourni pour construire l'index.")
return
# 1. Découper en chunks
self.document_chunks = self._split_documents_to_chunks(documents)
if not self.document_chunks:
logfire.error("Le découpage n'a produit aucun chunk. Impossible de construire l'index.")
return
# 2. Générer les embeddings
embeddings = self._generate_embeddings(self.document_chunks)
if embeddings is None or embeddings.shape[0] != len(self.document_chunks):
logfire.error("Problème de génération d'embeddings. Le nombre d'embeddings ne correspond pas au nombre de chunks.")
# Nettoyer pour éviter un état incohérent
self.document_chunks = []
self.index = None
# Supprimer les fichiers potentiellement corrompus
if os.path.exists(FAISS_INDEX_FILE): os.remove(FAISS_INDEX_FILE)
if os.path.exists(DOCUMENT_CHUNKS_FILE): os.remove(DOCUMENT_CHUNKS_FILE)
return
# 3. Valider chaque paire (chunk, vecteur) via VectordbIn avant l'ajout dans Faiss
valid_indices = []
for idx, (chunk, vector) in enumerate(zip(self.document_chunks, embeddings.tolist())):
try:
VectordbIn(chunk_id=chunk["id"], vector=vector)
valid_indices.append(idx)
except Exception as e:
logfire.warn(f" Chunk '{chunk.get('id')}' ignoré (validation VectordbIn échouée): {e}")
if not valid_indices:
logfire.error("Aucun vecteur valide après validation. Impossible de construire l'index.")
return
# Filtrer les chunks et embeddings valides
self.document_chunks = [self.document_chunks[i] for i in valid_indices]
embeddings = embeddings[valid_indices]
# 4. Créer l'index Faiss optimisé pour la similarité cosinus
dimension = embeddings.shape[1]
logfire.info(f"Création de l'index Faiss optimisé pour la similarité cosinus avec dimension {dimension}...")
# Normaliser les embeddings pour la similarité cosinus
faiss.normalize_L2(embeddings)
# Créer un index pour la similarité cosinus (IndexFlatIP = produit scalaire)
self.index = faiss.IndexFlatIP(dimension)
self.index.add(embeddings)
logfire.info(f"Index Faiss créé avec {self.index.ntotal} vecteurs.")
# 5. Sauvegarder l'index et les chunks
self._save_index_and_chunks()
def _save_index_and_chunks(self):
"""Sauvegarde l'index Faiss et la liste des chunks."""
if self.index is None or not self.document_chunks:
logfire.warn("Tentative de sauvegarde d'un index ou de chunks vides.")
return
os.makedirs(os.path.dirname(FAISS_INDEX_FILE), exist_ok=True)
os.makedirs(os.path.dirname(DOCUMENT_CHUNKS_FILE), exist_ok=True)
try:
logfire.info(f"Sauvegarde de l'index Faiss dans {FAISS_INDEX_FILE}...")
faiss.write_index(self.index, FAISS_INDEX_FILE)
logfire.info(f"Sauvegarde des chunks dans {DOCUMENT_CHUNKS_FILE}...")
with open(DOCUMENT_CHUNKS_FILE, 'wb') as f:
pickle.dump(self.document_chunks, f)
logfire.info("Index et chunks sauvegardés avec succès.")
except Exception as e:
logfire.error(f"Erreur lors de la sauvegarde de l'index/chunks: {e}")
def generate_rag_response(self, question: str, context_results: List[Dict[str, any]],
system_prompt_template: str, model: str,
temperature: float = TEMPERATURE
#, top_p: float = TOP_P
) -> str:
"""
Génère une réponse LLM via l'API Mistral en injectant le contexte RAG dans le prompt.
Args:
question: La question de l'utilisateur.
context_results: Liste de chunks retournés par search().
system_prompt_template: Template avec les placeholders {context_str} et {question}.
model: Nom du modèle Mistral à utiliser.
temperature: Température de génération.
top_p: Nucleus sampling (activable si besoin, désactivé par défaut via temperature).
Returns:
La réponse textuelle générée par le LLM.
"""
if context_results:
context_str = "\n\n---\n\n".join([
f"Source: {res['metadata'].get('source', 'Inconnue')} (Score: {res['score']:.1f}%)\nContenu: {res['text']}"
for res in context_results
])
else:
context_str = "Aucune information pertinente trouvée dans la base de connaissances pour cette question."
logfire.warn(f"Aucun contexte trouvé pour la question: '{question}'")
final_prompt = system_prompt_template.format(context_str=context_str, question=question)
messages = [{"role": "user", "content": final_prompt}]
try:
logfire.info(f"Appel LLM Mistral (modèle='{model}') pour la question: '{question}'")
max_retries = 6
wait_time = 10 # Attente initiale en secondes
for attempt in range(max_retries):
try:
response = self.mistral_client.chat.complete(
model=model,
messages=messages,
temperature=temperature,
# top_p=top_p, # Décommenter pour activer le nucleus sampling
)
break # Succès
except SDKError as e:
if getattr(e, "status_code", None) == 429 and attempt < max_retries - 1:
logfire.warning(
f"Rate limit 429 dans generate_rag_response (tentative {attempt + 1}/{max_retries}). "
f"Attente {wait_time}s..."
)
time.sleep(wait_time)
wait_time = min(wait_time * 2, 120) # Backoff exponentiel, max 120s
else:
raise
if response.choices:
logfire.info("Réponse LLM reçue.")
return response.choices[0].message.content
else:
logfire.warn("L'API Mistral n'a pas retourné de choix valide.")
return "Désolé, je n'ai pas pu générer de réponse valide pour le moment."
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de l'appel LLM: {e}")
return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."
except Exception as e:
logfire.error(f"Erreur inattendue lors de l'appel LLM: {e}")
return "Je suis désolé, une erreur technique m'empêche de répondre. Veuillez réessayer plus tard."
def search(self, query_text: str, k: int = 5, min_score: float = None) -> List[Dict[str, any]]:
"""
Recherche les k chunks les plus pertinents pour une requête.
Args:
query_text: Texte de la requête
k: Nombre de résultats à retourner
min_score: Score minimum (entre 0 et 1) pour inclure un résultat
Returns:
Liste des chunks pertinents avec leurs scores
"""
if self.index is None or not self.document_chunks:
logfire.warn("Recherche impossible: l'index Faiss n'est pas chargé ou est vide.")
return []
if not MISTRAL_API_KEY:
logfire.error("Recherche impossible: MISTRAL_API_KEY manquante pour générer l'embedding de la requête.")
return []
logfire.info(f"Recherche des {k} chunks les plus pertinents pour: '{query_text}'")
try:
# 1. Générer l'embedding de la requête
response = self.mistral_client.embeddings.create(
model=EMBEDDING_MODEL,
inputs=[query_text]
)
query_embedding = np.array([response.data[0].embedding]).astype('float32')
# Normaliser l'embedding de la requête pour la similarité cosinus
faiss.normalize_L2(query_embedding)
# 2. Rechercher dans l'index Faiss
# Pour IndexFlatIP: scores = produit scalaire (plus grand = meilleur)
# indices: index des chunks correspondants dans self.document_chunks
# Demander plus de résultats si un score minimum est spécifié
search_k = k * 3 if min_score is not None else k
scores, indices = self.index.search(query_embedding, search_k)
# 3. Formater les résultats
results = []
if indices.size > 0: # Vérifier s'il y a des résultats
for i, idx in enumerate(indices[0]):
if 0 <= idx < len(self.document_chunks): # Vérifier la validité de l'index
chunk = self.document_chunks[idx]
# Convertir le score en similarité (0-1)
# Pour IndexFlatIP avec vecteurs normalisés, le score est déjà entre -1 et 1
# On le convertit en pourcentage (0-100%)
raw_score = float(scores[0][i])
similarity = raw_score * 100
# Filtrer les résultats en fonction du score minimum
# Le min_score est entre 0 et 1, mais similarity est en pourcentage (0-100)
min_score_percent = min_score * 100 if min_score is not None else 0
if min_score is not None and similarity < min_score_percent:
logfire.debug(f"Document filtré (score {similarity:.2f}% < minimum {min_score_percent:.2f}%)")
continue
results.append({
"score": similarity, # Score de similarité en pourcentage
"raw_score": raw_score, # Score brut pour débogage
"text": chunk["text"],
"metadata": chunk["metadata"] # Contient source, category, chunk_id_in_doc, start_index etc.
})
else:
logfire.warn(f"Index Faiss {idx} hors limites (taille des chunks: {len(self.document_chunks)}).")
# Trier par score (similarité la plus élevée en premier)
results.sort(key=lambda x: x["score"], reverse=True)
# Limiter au nombre demandé (k) si nécessaire
if len(results) > k:
results = results[:k]
if min_score is not None:
min_score_percent = min_score * 100
logfire.info(f"{len(results)} chunks pertinents trouvés (score minimum: {min_score_percent:.2f}%).")
else:
logfire.info(f"{len(results)} chunks pertinents trouvés.")
return results
except SDKError as e:
logfire.error(f"Erreur API Mistral lors de la génération de l'embedding de la requête: {e}")
logfire.error(f" Détails: Status Code={e.status_code}, Body={e.body}")
return []
except Exception as e:
logfire.error(f"Erreur inattendue lors de la recherche: {e}")
return []