RAG_architectures / pinecone_utilsB.py
Aidahaouas's picture
embedding model Updated
fd1c27c
import streamlit as st
from config import sparse_index as indexB
from config import *
import nltk
import zlib
import base64
import json
import hashlib
import uuid
nltk.download('punkt_tab')
# Initialiser l'état de session pour Streamlit
if "bm25_corpus" not in st.session_state:
st.session_state.bm25_corpus = []
if "indexing_done" not in st.session_state:
st.session_state.indexing_done = False
def normalize_text(text):
"""Normalise le texte en supprimant les espaces superflus et les sauts de ligne."""
return " ".join(text.split())
def get_text_hash(text):
"""Retourne un hash du texte normalisé."""
normalized_text = normalize_text(text)
return hashlib.md5(normalized_text.encode("utf-8")).hexdigest()
def generate_unique_id():
"""Génère un ID unique."""
return str(uuid.uuid4())
def get_existing_vectors(index):
"""Récupère les hashs des textes déjà indexés dans Pinecone."""
existing_hashes = set()
try:
results = index.query(vector=[0] * 1024, top_k=10000, include_metadata=True)
for match in results.get("matches", []):
if "metadata" in match and "compressed_text" in match["metadata"]:
compressed_text = match["metadata"]["compressed_text"]
existing_hashes.add(get_text_hash(decompress_text(compressed_text)))
except Exception as e:
st.error(f"Erreur lors de la récupération des vecteurs existants : {e}")
return existing_hashes
def index_pdf_B(texts):
"""Indexe les textes en évitant les doublons."""
if not texts:
st.error("La liste des textes ne peut pas être vide.")
return
# Vérifier si l'indexation a déjà été effectuée
if st.session_state.indexing_done:
st.warning("L'indexation a déjà été effectuée. Ignorer.")
return
st.write("Indexation en cours, veuillez patienter...")
# Récupérer les textes déjà indexés
existing_hashes = get_existing_vectors(indexB)
# Initialiser BM25
st.session_state.bm25_corpus = texts
sparse_encoder.fit(texts)
for i, text in enumerate(texts):
chunks = split_text_into_chunks(text, max_chunk_size=1024)
for j, chunk in enumerate(chunks):
# Normaliser le chunk et calculer son hash
normalized_chunk = normalize_text(chunk)
chunk_hash = get_text_hash(normalized_chunk)
# Vérifier si ce texte est déjà indexé
if chunk_hash in existing_hashes:
continue # Ignorer ce document car il est déjà indexé
# Générer les valeurs sparse avec BM25
sparse_values = sparse_encoder.encode_documents([chunk])[0]
# Sérialiser les valeurs sparse en JSON
sparse_values_json = json.dumps(sparse_values)
# Créer les métadonnées avec le texte compressé et les valeurs sparse sérialisées
metadata = {
"compressed_text": compress_text(normalized_chunk),
"sparse_values": sparse_values_json # Utiliser la version sérialisée
}
# Vérifier la taille des métadonnées
metadata_size = get_metadata_size(metadata)
if metadata_size > 40960: # 40 KB
print(f"Attention : la taille des métadonnées ({metadata_size} bytes) dépasse la limite de 40960 bytes.")
continue
# Générer le vecteur dense avec SentenceTransformer
vector = model.encode([chunk]).tolist()[0]
# Générer un ID unique pour le vecteur
vector_id = generate_unique_id()
# Indexer le vecteur avec les métadonnées
indexB.upsert([(vector_id, vector, metadata)])
# Ajouter le hash du texte à la liste des textes indexés
existing_hashes.add(chunk_hash)
st.session_state.indexing_done = True # Marquer l'indexation comme terminée
st.success("Indexation terminée sans duplication de contenu.")
def hybrid_search(query, alpha, k, similarity_threshold):
"""Récupère les documents pertinents en combinant les résultats de Pinecone et BM25."""
try:
# Générer le vecteur dense pour la requête
query_vector = model.encode([query]).tolist()[0]
# Générer les valeurs sparse pour la requête
sparse_query = sparse_encoder.encode_queries([query])[0]
# Effectuer une recherche hybride dans l'indexB
results = indexB.query(
vector=query_vector, # Vecteur dense
sparse_vector=sparse_query, # Valeurs sparse
top_k=k, # Nombre de résultats à retourner
include_metadata=True, # Inclure les métadonnées
alpha=alpha
)
# Récupérer les documents pertinents
relevant_docs = []
total_words = 0
total_tokens = 0
for match in results.get("matches", []):
if "metadata" in match and "compressed_text" in match["metadata"]:
score = match.get("score", 0) # Score de similarité
if score >= similarity_threshold: # Filtrer par seuil
compressed_text = match["metadata"]["compressed_text"]
sparse_values_json = match["metadata"].get("sparse_values")
# Désérialiser les valeurs sparse si elles existent
sparse_values = json.loads(sparse_values_json) if sparse_values_json else None
# Décompression du texte
text = decompress_text(compressed_text)
relevant_docs.append({
"text": text,
"sparse_values": sparse_values,
"score": score
})
# Calcul du nombre de mots et de tokens
total_words += len(text.split()) # Nombre de mots (séparés par des espaces)
total_tokens += len(model.tokenizer.encode(text)) # Nombre de tokens
else:
print(f"Skipping match due to missing metadata or compressed_text: {match}")
# Calcul des moyennes
num_docs = len(relevant_docs)
avg_words_per_doc = total_words / num_docs if num_docs > 0 else 0
avg_tokens_per_doc = total_tokens / num_docs if num_docs > 0 else 0
print(f"Nombre de documents récupérés : {num_docs}")
print(f"Moyenne de mots par document : {avg_words_per_doc:.2f}")
print(f"Moyenne de tokens par document : {avg_tokens_per_doc:.2f}")
return relevant_docs
except Exception as e:
st.error(f"Erreur lors de la recherche hybride : {e}")
return []
def compress_text(text):
"""Compresse un texte en base64."""
compressed = zlib.compress(text.encode("utf-8"))
return base64.b64encode(compressed).decode("utf-8")
def decompress_text(compressed_text):
"""Décompresse un texte compressé en base64."""
try:
compressed_data = base64.b64decode(compressed_text.encode("utf-8"))
return zlib.decompress(compressed_data).decode("utf-8")
except Exception as e:
st.error(f"Erreur de décompression : {e}")
return ""
def split_text_into_chunks(text, max_chunk_size=1024):
"""Divise un texte en morceaux de taille maximale `max_chunk_size`."""
return [text[i:i+max_chunk_size] for i in range(0, len(text), max_chunk_size)]
def get_metadata_size(metadata):
"""Calcule la taille des métadonnées en octets."""
return len(str(metadata).encode("utf-8"))