Spaces:
Sleeping
Sleeping
File size: 8,444 Bytes
5327a45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
##############################################################################################
### Script de création de la base de données FAISS des articles
###
### Ce script
### - charge la table articles depuis le dataset HF Loren/articles_database
### - la traite par batch :
### - création de chunks de texte
### - création des embeddings avec le modèle SentenceTransformer "intfloat/e5-small"
### - ajout des embeddings dans un index FAISS
### - sauvegarde des métadonnées des chunks dans un fichier parquet
### - sauvegarde de l'index FAISS dans un fichier faiss_index.bin
### - upload dans le dataset HF Loren/articles_faiss
###
### 👉 L'index Faiss peut alors être utilisé par un space Hugging Face
##############################################################################################
import os
import torch
import duckdb
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub import HfApi, HfFolder, CommitOperationAdd
import faiss
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from functools import partial
import pyarrow as pa
import pyarrow.parquet as pq
from pathlib import Path
from dotenv import load_dotenv
# Fonctions
# Batch processing function
def batch_process(list_articles: list, faiss_id_start: int) -> int:
"""
Traite un batch d'articles pour générer des embeddings et des métadonnées,
puis les sauvegarde de manière sécurisée pour garantir la persistance en cas de problème.
Étapes réalisées :
1. Découpage de chaque article en chunks via le splitter.
2. Création d'un dictionnaire de métadonnées pour chaque chunk contenant :
- faiss_id : identifiant unique aligné avec l'index FAISS
- document_id : identifiant de l'article
- chunk_text : texte du chunk
3. Calcul des embeddings pour tous les chunks du batch.
4. Ajout des embeddings au FAISS index existant (append).
5. Écriture immédiate de l'index FAISS sur disque pour assurer la persistance.
6. Sauvegarde des métadonnées batch dans un fichier Parquet distinct.
Args:
list_articles (list): Liste de tuples (document_id, document_text) représentant les articles du batch.
faiss_id_start (int): Identifiant de départ pour le premier chunk du batch,
utilisé pour aligner FAISS et les métadonnées.
Returns:
int: Identifiant FAISS suivant, à utiliser pour le batch suivant afin de maintenir l'alignement.
Notes :
- Cette fonction est conçue pour être utilisée batch par batch.
- Les fichiers Parquet et le fichier FAISS sont mis à jour à chaque batch pour éviter toute perte de données.
"""
global faiss_index
try:
list_chunks = []
list_metadata = []
for doc_id, doc_content in list_articles:
chunks = splitter.split_text(doc_content)
for chunk_text in chunks:
list_chunks.append(chunk_text)
list_metadata.append({
"faiss_id": faiss_id_start,
"document_id": doc_id,
"chunk_text": chunk_text
})
faiss_id_start += 1
# Embeddings
if list_chunks:
passage_texts = [f"passage: {p}" for p in list_chunks]
embeddings = model.encode(passage_texts, convert_to_numpy=True,
normalize_embeddings=True)
faiss_index.add(embeddings)
faiss.write_index(faiss_index, str(FAISS_INDEX_FILE))
# Sauvegarde batch métadonnées en Parquet
if list_metadata:
table = pa.Table.from_pylist(list_metadata)
batch_file = PARQUET_DIR / f"metadata_batch_{faiss_id_start}.parquet"
pq.write_table(table, batch_file)
return faiss_id_start
except Exception as e:
print(f"ERROR in batch_process function : {e}")
return None
##
# Initialisations
global faiss_index
print("Initialisations ...")
load_dotenv()
HF_TOKEN = os.getenv('API_HF_TOKEN')
REPO_ID = "Loren/articles_database"
DATA_DIR = Path("../../Data") # dossier parent du script
CHUNK_SIZE = 250
CHUNK_OVERLAP = 50
BATCH_SIZE = 1000
MODEL_NAME = "intfloat/multilingual-e5-small"
FAISS_INDEX_FILE = DATA_DIR / "faiss_index.bin"
PARQUET_DIR = DATA_DIR / "parquet_metadata"
CACHE_DIR = "/tmp"
os.makedirs(CACHE_DIR, exist_ok=True)
# Rediriger le cache HF globalement
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
# Téléchargement des fichiers Parquet depuis Hugging Face
print("Téléchargement des fichiers Parquet depuis Hugging Face ...")
articles_parquet = hf_hub_download(
repo_id=REPO_ID,
filename="articles.parquet",
repo_type="dataset",
cache_dir=CACHE_DIR)
# Connexion DuckDB en mémoire
con = duckdb.connect()
# Créer des tables DuckDB directement à partir des fichiers Parquet
print("Création des vues DuckDB à partir des fichiers Parquet ...")
con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
# Creating the plitter for chunking document
print("Initialisation du text splitter ...")
splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
keep_separator='end',
separators=["\n\n", "\n", "."]
)
# Creating the Sentence transformer model
print("Initialisation du modèle de Sentence Transformer ...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"*** Device: {device}")
model = SentenceTransformer(MODEL_NAME, device=device)
# Creating the Faiss index
embedding_dim = model.get_sentence_embedding_dimension()
faiss_index = faiss.IndexFlatIP(embedding_dim)
faiss_id_counter = 0 # compteur global pour lier faiss_id et métadonnées
# Traitement par batchs
print("Création des batches et traitement ...")
cursor = con.execute("""
SELECT article_id, article_text
FROM articles
WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100""")
# Création d'un itérateur de batches
fetch_batch = partial(cursor.fetchmany, BATCH_SIZE)
for batch_num, batch in enumerate(iter(fetch_batch, []), start=1):
print("Traitement batch no ", batch_num, " ...")
faiss_id_counter = batch_process(batch, faiss_id_counter)
if not faiss_id_counter:
print("*** Erreur traitement batch no ", batch_num)
print("\n✅ Traitement terminé")
# Upload des fichiers vers HF
# Création du dataset HF
REPO_ID = "Loren/articles_faiss"
api = HfApi()
HfFolder.save_token(HF_TOKEN)
# Vérifier si le dataset existe
try:
repo_info = api.dataset_info(REPO_ID, token=HF_TOKEN)
print(f"Dataset {REPO_ID} existe déjà, suppression en cours...")
api.delete_repo(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
except Exception as e:
print(f"Dataset n'existait pas : {e}")
# Créer le dataset (privé)
api.create_repo(repo_id=REPO_ID, repo_type="dataset", exist_ok=True, private=True, token=HF_TOKEN)
print(f"Dataset {REPO_ID} créé avec succès.")
# Récupérer la liste de fichiers parquet
print("Upload des fichiers metadatas dans le dataset hugging face ", REPO_ID, " ...")
parquet_files = [
os.path.join(PARQUET_DIR, f)
for f in os.listdir(PARQUET_DIR)
if f.endswith(".parquet")
]
# Ajouter tous les fichiers
operations = [
CommitOperationAdd(
path_in_repo=f"data/{os.path.basename(f)}",
path_or_fileobj=f
)
for f in parquet_files
]
api.create_commit(
repo_id=REPO_ID,
repo_type="dataset",
operations=operations,
commit_message="Upload batch metadata parquet files"
)
print("✅ Upload metadatas terminé !")
print("Upload de l'index Faiss dans le dataset hugging face ", REPO_ID, " ...")
upload_file(
path_or_fileobj=FAISS_INDEX_FILE,
path_in_repo=FAISS_INDEX_FILE.name,
repo_id=REPO_ID,
repo_type="dataset",
token=HF_TOKEN
)
print("✅ Upload faiss index terminé")
con.close()
print("✅ Traitement terminé") |