from datasets import load_dataset from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import Chroma from langchain_community.embeddings import HuggingFaceEmbeddings import os import time # Paramètres CHUNK_SIZE = 500 CHUNK_OVERLAP = 100 DB_PATH = os.path.abspath("../../db") # Chemin racine du projet EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" print("[INFO] Chargement du dataset Code du Travail...") dataset = load_dataset("louisbrulenaudet/code-travail")['train'] print(f"[INFO] {len(dataset)} articles récupérés.") splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) documents = [] metadatas = [] for idx, item in enumerate(dataset): if idx == 0: print(f"[DEBUG] Clés disponibles dans un item : {list(item.keys())}") ref = item.get("ref", "(sans ref)") texte = item.get("texte", "") if not texte: print(f"[WARN] Article sans texte, index {idx}, ref : {ref}") continue for chunk in splitter.split_text(texte): documents.append(chunk) metadatas.append({ "source": "textes_loi", # Pour cohérence avec le filtre de l'assistant "ref": ref, "num": item.get("num"), "dateDebut": item.get("dateDebut"), "dateFin": item.get("dateFin"), "etat": item.get("etat"), "type": item.get("type"), "id": item.get("id"), }) print(f"[INFO] {len(documents)} chunks générés à partir du dataset.") print(f"[INFO] Chargement des embeddings ({EMBEDDING_MODEL})...") embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL) # Charger ou créer la base Chroma existante if os.path.exists(DB_PATH): print(f"[INFO] Ouverture de la base vectorielle existante : {DB_PATH}") db = Chroma(persist_directory=DB_PATH, embedding_function=embeddings) else: print(f"[INFO] Création d'une nouvelle base vectorielle : {DB_PATH}") os.makedirs(DB_PATH, exist_ok=True) db = Chroma(persist_directory=DB_PATH, embedding_function=embeddings) # Ajout des nouveaux documents par batch BATCH_SIZE = 5000 # inférieur à la limite Chroma print("[INFO] Ajout des nouveaux documents à la base vectorielle (par batch)...") t0 = time.time() for i in range(0, len(documents), BATCH_SIZE): db.add_texts( documents[i:i+BATCH_SIZE], metadatas=metadatas[i:i+BATCH_SIZE] ) print(f"[INFO] Batch {i//BATCH_SIZE + 1} ajouté ({min(i+BATCH_SIZE, len(documents))}/{len(documents)})") db.persist() t1 = time.time() print(f"[SUCCESS] {len(documents)} chunks ajoutés à la base vectorielle en {t1-t0:.1f} secondes.") # Affichage du total de documents dans la base try: total_docs = db._collection.count() print(f"[INFO] Total de documents dans la base vectorielle après ajout : {total_docs}") except Exception as e: print(f"[WARN] Impossible de compter le nombre total de documents : {e}") print(f"[INFO] La base vectorielle est prête dans : {DB_PATH}")