api_search_articles / app /database.py
Loren's picture
Upload database.py
e1182e0 verified
import os
# Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
# définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
cache_dir = "/tmp"
os.makedirs(cache_dir, exist_ok=True)
# Rediriger le cache HF globalement
os.environ["HF_HOME"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
from typing import List, Dict, Any
import duckdb
import faiss
import pandas as pd
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
from datasets import load_dataset
from dotenv import load_dotenv
import pyarrow as pa
import pyarrow.compute as pc
import logging
logging.basicConfig(level=logging.DEBUG)
# Initialisations
load_dotenv()
HF_TOKEN = os.getenv('API_HF_TOKEN')
REPO_ID = "Loren/articles_database"
FAISS_REPO_ID = "Loren/articles_faiss"
FAISS_INDEX_FILE = "faiss_index.bin"
MODEL_NAME = "intfloat/multilingual-e5-small"
#CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base"
revision = "a6258e9d2b1a11aa7bccdff9efde562bbca4393d"
#ROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual"
# Téléchargement des fichiers Parquet depuis Hugging Face
articles_parquet = hf_hub_download(
repo_id=REPO_ID,
filename="articles_checked.parquet",
repo_type="dataset",
cache_dir=cache_dir)
tags_parquet = hf_hub_download(
repo_id=REPO_ID,
filename="tags.parquet",
repo_type="dataset",
cache_dir=cache_dir)
tag_article_parquet = hf_hub_download(
repo_id=REPO_ID,
filename="tag_article.parquet",
repo_type="dataset",
cache_dir=cache_dir)
# Connexion DuckDB en mémoire
con = duckdb.connect()
# Créer des tables DuckDB directement à partir des fichiers Parquet
con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
# Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
hf_faiss_index = hf_hub_download(
repo_id=FAISS_REPO_ID,
filename=FAISS_INDEX_FILE,
repo_type="dataset",
token=HF_TOKEN,
cache_dir=cache_dir
)
# Chargement de l’index FAISS
faiss_index = faiss.read_index(hf_faiss_index)
# Téléchargement des metadatas Faiss depuis le dataset Hugging Face
dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
arrow_table = dataset.data
# Creation du Sentence transformer model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"*** Device: {device}")
model = SentenceTransformer(MODEL_NAME, device=device)
# Création du cross-encoder
cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
revision=revision,
trust_remote_code=True)
# Fonctions d'accès aux données
def fetch_tags() -> List[str]:
"""
Récupère la liste de tous les tags disponibles dans la base de données.
Returns:
Dict: Un dictionnaire contenant le statut et les résultats.
- Si succès :
{
"status": "ok",
"result": List[str] # Liste des noms de tags triés par ordre alphabétique
}
- En cas d'erreur :
{
"status": "error",
"code": str, # Nom de l'exception
"message": str # Message de l'exception
}
"""
try:
query = "SELECT tag_name FROM tags ORDER BY tag_name"
result = con.execute(query).fetchall()
return {"status": "ok", "result": [row[0] for row in result]}
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}
def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
"""
Récupère les articles associés à un ou plusieurs tags.
Args:
tags (List[str]): Une liste de noms de tags pour filtrer les articles.
Returns:
Dict: Un dictionnaire contenant le statut et les résultats.
- Si succès :
{
"status": "ok",
"result": List[Dict] # Liste de dictionnaires représentant les articles
}
Chaque dictionnaire contient les clés :
- 'article_id': int, ID de l'article
- 'article_title': str, Titre de l'article
- 'article_url': str, URL de l'article
- En cas d'erreur ou si aucun tag fourni :
{
"status": "error",
"code": str, # Code d'erreur ou nom de l'exception
"message": str # Message d'erreur
}
Notes:
- Si la liste `tags` est vide, la fonction retourne une liste vide.
- Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
"""
if not tags:
return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
try:
placeholders = ",".join(["?"] * len(tags))
query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
CASE WHEN a.article_online
THEN a.article_url
ELSE 'Article unavailable' END AS url,
FROM tags t, tag_article ta, articles a
WHERE t.tag_id = ta.tag_id
AND ta.article_id = a.article_id
AND t.tag_name IN ({placeholders})
"""
result = con.execute(query, tags).fetchdf()
return {"status": "ok", "result": result.to_dict(orient="records")}
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}
def fetch_query_results(query: str, k_model: int = 10,
k_cross: int = 5, use_rerank: bool = True
) -> Dict[str, Any]:
"""
Exécute une requête de recherche sémantique avec FAISS, puis (optionnellement)
rerank avec un cross-encoder et retourne les meilleurs passages enrichis avec
des métadonnées provenant de DuckDB.
Paramètres
----------
query : str
La requête texte fournie par l'utilisateur.
k_model : int, optionnel (défaut = 10)
Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
k_cross : int, optionnel (défaut = 5)
Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
use_rerank : bool, optionnel (défaut = True)
Si False, on désactive complètement le cross-encoder et le rerank.
Retour
------
Dict[str, Any]
Un dictionnaire contenant :
- status : "ok" si succès, sinon "error"
- result : liste de résultats (si succès)
- code et message : informations d'erreur (si échec)
"""
if not query:
return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
try:
query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
distances, indices = faiss_index.search(query_vec, k_model)
# Résultats FAISS
faiss_ids_list = indices[0].tolist()
distances_list = distances[0].tolist()
# Filtrer Arrow sur les IDs trouvés
filtered_table = arrow_table.filter(
pc.is_in(arrow_table['faiss_id'],
value_set=pa.array(faiss_ids_list))
)
# Convertir Arrow → pandas pour ajouter la distance
df = filtered_table.to_pandas()
# Ajouter la distance en gardant l'ordre faiss_ids_list
distance_map = dict(zip(faiss_ids_list, distances_list))
df["distance"] = df["faiss_id"].map(distance_map)
if use_rerank:
# Cross-encoder
df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip()
top_passages = df["chunk_text"].tolist()
cross_input = [(query, p) for p in top_passages]
cross_scores = cross_encoder.predict(cross_input)
# Rerank
df["cross_score"] = cross_scores
df = df.sort_values(by="cross_score", ascending=False)
# Garder top k_cross
df_top = df.head(k_cross)
else:
df = df.sort_values(by="distance", ascending=False)
df["cross_score"] = df["distance"]
# Garder top k_model
df_top = df.head(k_model)
# Enregistrer dans DuckDB
con.register("faiss_tmp", df_top)
sql = """
SELECT
f.faiss_id,
f.document_id,
f.distance,
f.cross_score,
f.chunk_text,
a.article_title,
a.article_url,
CASE WHEN a.article_online
THEN a.article_url
ELSE 'Article unavailable' END AS url,
STRING_AGG(t.tag_name, ', ') AS tags
FROM faiss_tmp f
JOIN articles a ON f.document_id = a.article_id
JOIN tag_article ta ON a.article_id = ta.article_id
JOIN tags t ON ta.tag_id = t.tag_id
WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
a.article_title, a.article_online, a.article_url
ORDER BY AVG(f.cross_score) DESC
"""
duck_res = con.execute(sql).fetchdf()
# Liste finale de dictionnaires
list_result = duck_res.to_dict(orient="records")
return {"status": "ok", "result": list_result}
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}