Spaces:
Sleeping
Sleeping
| import os | |
| # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être | |
| # définie avant d’importer datasets ou transformers, sinon elle sera ignorée. | |
| cache_dir = "/tmp" | |
| os.makedirs(cache_dir, exist_ok=True) | |
| # Rediriger le cache HF globalement | |
| os.environ["HF_HOME"] = cache_dir | |
| os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets") | |
| os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers") | |
| from typing import List, Dict, Any | |
| import duckdb | |
| import faiss | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| import torch | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| import pyarrow as pa | |
| import pyarrow.compute as pc | |
| import logging | |
| logging.basicConfig(level=logging.DEBUG) | |
| # Initialisations | |
| load_dotenv() | |
| HF_TOKEN = os.getenv('API_HF_TOKEN') | |
| REPO_ID = "Loren/articles_database" | |
| FAISS_REPO_ID = "Loren/articles_faiss" | |
| FAISS_INDEX_FILE = "faiss_index.bin" | |
| MODEL_NAME = "intfloat/multilingual-e5-small" | |
| #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2" | |
| CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base" | |
| revision = "a6258e9d2b1a11aa7bccdff9efde562bbca4393d" | |
| #ROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual" | |
| # Téléchargement des fichiers Parquet depuis Hugging Face | |
| articles_parquet = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename="articles_checked.parquet", | |
| repo_type="dataset", | |
| cache_dir=cache_dir) | |
| tags_parquet = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename="tags.parquet", | |
| repo_type="dataset", | |
| cache_dir=cache_dir) | |
| tag_article_parquet = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename="tag_article.parquet", | |
| repo_type="dataset", | |
| cache_dir=cache_dir) | |
| # Connexion DuckDB en mémoire | |
| con = duckdb.connect() | |
| # Créer des tables DuckDB directement à partir des fichiers Parquet | |
| con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')") | |
| con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')") | |
| con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')") | |
| # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face | |
| hf_faiss_index = hf_hub_download( | |
| repo_id=FAISS_REPO_ID, | |
| filename=FAISS_INDEX_FILE, | |
| repo_type="dataset", | |
| token=HF_TOKEN, | |
| cache_dir=cache_dir | |
| ) | |
| # Chargement de l’index FAISS | |
| faiss_index = faiss.read_index(hf_faiss_index) | |
| # Téléchargement des metadatas Faiss depuis le dataset Hugging Face | |
| dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN) | |
| arrow_table = dataset.data | |
| # Creation du Sentence transformer model | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"*** Device: {device}") | |
| model = SentenceTransformer(MODEL_NAME, device=device) | |
| # Création du cross-encoder | |
| cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device, | |
| revision=revision, | |
| trust_remote_code=True) | |
| # Fonctions d'accès aux données | |
| def fetch_tags() -> List[str]: | |
| """ | |
| Récupère la liste de tous les tags disponibles dans la base de données. | |
| Returns: | |
| Dict: Un dictionnaire contenant le statut et les résultats. | |
| - Si succès : | |
| { | |
| "status": "ok", | |
| "result": List[str] # Liste des noms de tags triés par ordre alphabétique | |
| } | |
| - En cas d'erreur : | |
| { | |
| "status": "error", | |
| "code": str, # Nom de l'exception | |
| "message": str # Message de l'exception | |
| } | |
| """ | |
| try: | |
| query = "SELECT tag_name FROM tags ORDER BY tag_name" | |
| result = con.execute(query).fetchall() | |
| return {"status": "ok", "result": [row[0] for row in result]} | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |
| def fetch_articles_by_tags(tags: List[str]) -> List[Dict]: | |
| """ | |
| Récupère les articles associés à un ou plusieurs tags. | |
| Args: | |
| tags (List[str]): Une liste de noms de tags pour filtrer les articles. | |
| Returns: | |
| Dict: Un dictionnaire contenant le statut et les résultats. | |
| - Si succès : | |
| { | |
| "status": "ok", | |
| "result": List[Dict] # Liste de dictionnaires représentant les articles | |
| } | |
| Chaque dictionnaire contient les clés : | |
| - 'article_id': int, ID de l'article | |
| - 'article_title': str, Titre de l'article | |
| - 'article_url': str, URL de l'article | |
| - En cas d'erreur ou si aucun tag fourni : | |
| { | |
| "status": "error", | |
| "code": str, # Code d'erreur ou nom de l'exception | |
| "message": str # Message d'erreur | |
| } | |
| Notes: | |
| - Si la liste `tags` est vide, la fonction retourne une liste vide. | |
| - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis. | |
| """ | |
| if not tags: | |
| return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."} | |
| try: | |
| placeholders = ",".join(["?"] * len(tags)) | |
| query = f"""SELECT distinct a.article_id, a.article_title, a.article_url, | |
| CASE WHEN a.article_online | |
| THEN a.article_url | |
| ELSE 'Article unavailable' END AS url, | |
| FROM tags t, tag_article ta, articles a | |
| WHERE t.tag_id = ta.tag_id | |
| AND ta.article_id = a.article_id | |
| AND t.tag_name IN ({placeholders}) | |
| """ | |
| result = con.execute(query, tags).fetchdf() | |
| return {"status": "ok", "result": result.to_dict(orient="records")} | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |
| def fetch_query_results(query: str, k_model: int = 10, | |
| k_cross: int = 5, use_rerank: bool = True | |
| ) -> Dict[str, Any]: | |
| """ | |
| Exécute une requête de recherche sémantique avec FAISS, puis (optionnellement) | |
| rerank avec un cross-encoder et retourne les meilleurs passages enrichis avec | |
| des métadonnées provenant de DuckDB. | |
| Paramètres | |
| ---------- | |
| query : str | |
| La requête texte fournie par l'utilisateur. | |
| k_model : int, optionnel (défaut = 10) | |
| Nombre de résultats les plus proches à récupérer depuis l'index FAISS. | |
| k_cross : int, optionnel (défaut = 5) | |
| Nombre de résultats finaux à conserver après reranking avec le cross-encoder. | |
| use_rerank : bool, optionnel (défaut = True) | |
| Si False, on désactive complètement le cross-encoder et le rerank. | |
| Retour | |
| ------ | |
| Dict[str, Any] | |
| Un dictionnaire contenant : | |
| - status : "ok" si succès, sinon "error" | |
| - result : liste de résultats (si succès) | |
| - code et message : informations d'erreur (si échec) | |
| """ | |
| if not query: | |
| return {"status": "error", "code": "no_query", "message": "Aucun query fourni."} | |
| try: | |
| query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True) | |
| distances, indices = faiss_index.search(query_vec, k_model) | |
| # Résultats FAISS | |
| faiss_ids_list = indices[0].tolist() | |
| distances_list = distances[0].tolist() | |
| # Filtrer Arrow sur les IDs trouvés | |
| filtered_table = arrow_table.filter( | |
| pc.is_in(arrow_table['faiss_id'], | |
| value_set=pa.array(faiss_ids_list)) | |
| ) | |
| # Convertir Arrow → pandas pour ajouter la distance | |
| df = filtered_table.to_pandas() | |
| # Ajouter la distance en gardant l'ordre faiss_ids_list | |
| distance_map = dict(zip(faiss_ids_list, distances_list)) | |
| df["distance"] = df["faiss_id"].map(distance_map) | |
| if use_rerank: | |
| # Cross-encoder | |
| df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip() | |
| top_passages = df["chunk_text"].tolist() | |
| cross_input = [(query, p) for p in top_passages] | |
| cross_scores = cross_encoder.predict(cross_input) | |
| # Rerank | |
| df["cross_score"] = cross_scores | |
| df = df.sort_values(by="cross_score", ascending=False) | |
| # Garder top k_cross | |
| df_top = df.head(k_cross) | |
| else: | |
| df = df.sort_values(by="distance", ascending=False) | |
| df["cross_score"] = df["distance"] | |
| # Garder top k_model | |
| df_top = df.head(k_model) | |
| # Enregistrer dans DuckDB | |
| con.register("faiss_tmp", df_top) | |
| sql = """ | |
| SELECT | |
| f.faiss_id, | |
| f.document_id, | |
| f.distance, | |
| f.cross_score, | |
| f.chunk_text, | |
| a.article_title, | |
| a.article_url, | |
| CASE WHEN a.article_online | |
| THEN a.article_url | |
| ELSE 'Article unavailable' END AS url, | |
| STRING_AGG(t.tag_name, ', ') AS tags | |
| FROM faiss_tmp f | |
| JOIN articles a ON f.document_id = a.article_id | |
| JOIN tag_article ta ON a.article_id = ta.article_id | |
| JOIN tags t ON ta.tag_id = t.tag_id | |
| WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100 | |
| GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text, | |
| a.article_title, a.article_online, a.article_url | |
| ORDER BY AVG(f.cross_score) DESC | |
| """ | |
| duck_res = con.execute(sql).fetchdf() | |
| # Liste finale de dictionnaires | |
| list_result = duck_res.to_dict(orient="records") | |
| return {"status": "ok", "result": list_result} | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |