import os # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être # définie avant d’importer datasets ou transformers, sinon elle sera ignorée. cache_dir = "/tmp" os.makedirs(cache_dir, exist_ok=True) # Rediriger le cache HF globalement os.environ["HF_HOME"] = cache_dir os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets") os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers") from typing import List, Dict, Any import duckdb import faiss import pandas as pd from huggingface_hub import hf_hub_download from sentence_transformers import SentenceTransformer, CrossEncoder import torch from datasets import load_dataset from dotenv import load_dotenv import pyarrow as pa import pyarrow.compute as pc import logging logging.basicConfig(level=logging.DEBUG) # Initialisations load_dotenv() HF_TOKEN = os.getenv('API_HF_TOKEN') REPO_ID = "Loren/articles_database" FAISS_REPO_ID = "Loren/articles_faiss" FAISS_INDEX_FILE = "faiss_index.bin" MODEL_NAME = "intfloat/multilingual-e5-small" #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2" CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base" revision = "a6258e9d2b1a11aa7bccdff9efde562bbca4393d" #ROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual" # Téléchargement des fichiers Parquet depuis Hugging Face articles_parquet = hf_hub_download( repo_id=REPO_ID, filename="articles_checked.parquet", repo_type="dataset", cache_dir=cache_dir) tags_parquet = hf_hub_download( repo_id=REPO_ID, filename="tags.parquet", repo_type="dataset", cache_dir=cache_dir) tag_article_parquet = hf_hub_download( repo_id=REPO_ID, filename="tag_article.parquet", repo_type="dataset", cache_dir=cache_dir) # Connexion DuckDB en mémoire con = duckdb.connect() # Créer des tables DuckDB directement à partir des fichiers Parquet con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')") con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')") con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')") # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face hf_faiss_index = hf_hub_download( repo_id=FAISS_REPO_ID, filename=FAISS_INDEX_FILE, repo_type="dataset", token=HF_TOKEN, cache_dir=cache_dir ) # Chargement de l’index FAISS faiss_index = faiss.read_index(hf_faiss_index) # Téléchargement des metadatas Faiss depuis le dataset Hugging Face dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN) arrow_table = dataset.data # Creation du Sentence transformer model device = "cuda" if torch.cuda.is_available() else "cpu" print(f"*** Device: {device}") model = SentenceTransformer(MODEL_NAME, device=device) # Création du cross-encoder cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device, revision=revision, trust_remote_code=True) # Fonctions d'accès aux données def fetch_tags() -> List[str]: """ Récupère la liste de tous les tags disponibles dans la base de données. Returns: Dict: Un dictionnaire contenant le statut et les résultats. - Si succès : { "status": "ok", "result": List[str] # Liste des noms de tags triés par ordre alphabétique } - En cas d'erreur : { "status": "error", "code": str, # Nom de l'exception "message": str # Message de l'exception } """ try: query = "SELECT tag_name FROM tags ORDER BY tag_name" result = con.execute(query).fetchall() return {"status": "ok", "result": [row[0] for row in result]} except Exception as e: return {"status": "error", "code": type(e).__name__, "message": str(e)} def fetch_articles_by_tags(tags: List[str]) -> List[Dict]: """ Récupère les articles associés à un ou plusieurs tags. Args: tags (List[str]): Une liste de noms de tags pour filtrer les articles. Returns: Dict: Un dictionnaire contenant le statut et les résultats. - Si succès : { "status": "ok", "result": List[Dict] # Liste de dictionnaires représentant les articles } Chaque dictionnaire contient les clés : - 'article_id': int, ID de l'article - 'article_title': str, Titre de l'article - 'article_url': str, URL de l'article - En cas d'erreur ou si aucun tag fourni : { "status": "error", "code": str, # Code d'erreur ou nom de l'exception "message": str # Message d'erreur } Notes: - Si la liste `tags` est vide, la fonction retourne une liste vide. - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis. """ if not tags: return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."} try: placeholders = ",".join(["?"] * len(tags)) query = f"""SELECT distinct a.article_id, a.article_title, a.article_url, CASE WHEN a.article_online THEN a.article_url ELSE 'Article unavailable' END AS url, FROM tags t, tag_article ta, articles a WHERE t.tag_id = ta.tag_id AND ta.article_id = a.article_id AND t.tag_name IN ({placeholders}) """ result = con.execute(query, tags).fetchdf() return {"status": "ok", "result": result.to_dict(orient="records")} except Exception as e: return {"status": "error", "code": type(e).__name__, "message": str(e)} def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5, use_rerank: bool = True ) -> Dict[str, Any]: """ Exécute une requête de recherche sémantique avec FAISS, puis (optionnellement) rerank avec un cross-encoder et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB. Paramètres ---------- query : str La requête texte fournie par l'utilisateur. k_model : int, optionnel (défaut = 10) Nombre de résultats les plus proches à récupérer depuis l'index FAISS. k_cross : int, optionnel (défaut = 5) Nombre de résultats finaux à conserver après reranking avec le cross-encoder. use_rerank : bool, optionnel (défaut = True) Si False, on désactive complètement le cross-encoder et le rerank. Retour ------ Dict[str, Any] Un dictionnaire contenant : - status : "ok" si succès, sinon "error" - result : liste de résultats (si succès) - code et message : informations d'erreur (si échec) """ if not query: return {"status": "error", "code": "no_query", "message": "Aucun query fourni."} try: query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True) distances, indices = faiss_index.search(query_vec, k_model) # Résultats FAISS faiss_ids_list = indices[0].tolist() distances_list = distances[0].tolist() # Filtrer Arrow sur les IDs trouvés filtered_table = arrow_table.filter( pc.is_in(arrow_table['faiss_id'], value_set=pa.array(faiss_ids_list)) ) # Convertir Arrow → pandas pour ajouter la distance df = filtered_table.to_pandas() # Ajouter la distance en gardant l'ordre faiss_ids_list distance_map = dict(zip(faiss_ids_list, distances_list)) df["distance"] = df["faiss_id"].map(distance_map) if use_rerank: # Cross-encoder df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip() top_passages = df["chunk_text"].tolist() cross_input = [(query, p) for p in top_passages] cross_scores = cross_encoder.predict(cross_input) # Rerank df["cross_score"] = cross_scores df = df.sort_values(by="cross_score", ascending=False) # Garder top k_cross df_top = df.head(k_cross) else: df = df.sort_values(by="distance", ascending=False) df["cross_score"] = df["distance"] # Garder top k_model df_top = df.head(k_model) # Enregistrer dans DuckDB con.register("faiss_tmp", df_top) sql = """ SELECT f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text, a.article_title, a.article_url, CASE WHEN a.article_online THEN a.article_url ELSE 'Article unavailable' END AS url, STRING_AGG(t.tag_name, ', ') AS tags FROM faiss_tmp f JOIN articles a ON f.document_id = a.article_id JOIN tag_article ta ON a.article_id = ta.article_id JOIN tags t ON ta.tag_id = t.tag_id WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100 GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text, a.article_title, a.article_online, a.article_url ORDER BY AVG(f.cross_score) DESC """ duck_res = con.execute(sql).fetchdf() # Liste finale de dictionnaires list_result = duck_res.to_dict(orient="records") return {"status": "ok", "result": list_result} except Exception as e: return {"status": "error", "code": type(e).__name__, "message": str(e)}