Loren commited on
Commit
d721f56
·
verified ·
1 Parent(s): 1751d36

Delete database.py

Browse files
Files changed (1) hide show
  1. database.py +0 -272
database.py DELETED
@@ -1,272 +0,0 @@
1
- import os
2
- # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
3
- # définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
4
- cache_dir = "/tmp"
5
- os.makedirs(cache_dir, exist_ok=True)
6
-
7
- # Rediriger le cache HF globalement
8
- os.environ["HF_HOME"] = cache_dir
9
- os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
10
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
11
-
12
- from typing import List, Dict, Any
13
- import duckdb
14
- import faiss
15
- import pandas as pd
16
- from huggingface_hub import hf_hub_download
17
- from sentence_transformers import SentenceTransformer, CrossEncoder
18
- import torch
19
- from datasets import load_dataset
20
- from dotenv import load_dotenv
21
- import pyarrow as pa
22
- import pyarrow.compute as pc
23
-
24
- import logging
25
- logging.basicConfig(level=logging.DEBUG)
26
-
27
- # Initialisations
28
- load_dotenv()
29
- HF_TOKEN = os.getenv('API_HF_TOKEN')
30
-
31
- REPO_ID = "Loren/articles_database"
32
- FAISS_REPO_ID = "Loren/articles_faiss"
33
- FAISS_INDEX_FILE = "faiss_index.bin"
34
- MODEL_NAME = "intfloat/multilingual-e5-small"
35
- #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
36
- CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base"
37
- revision = "a6258e9d2b1a11aa7bccdff9efde562bbca4393d"
38
- #ROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual"
39
-
40
- # Téléchargement des fichiers Parquet depuis Hugging Face
41
- articles_parquet = hf_hub_download(
42
- repo_id=REPO_ID,
43
- filename="articles_checked.parquet",
44
- repo_type="dataset",
45
- cache_dir=cache_dir)
46
- tags_parquet = hf_hub_download(
47
- repo_id=REPO_ID,
48
- filename="tags.parquet",
49
- repo_type="dataset",
50
- cache_dir=cache_dir)
51
- tag_article_parquet = hf_hub_download(
52
- repo_id=REPO_ID,
53
- filename="tag_article.parquet",
54
- repo_type="dataset",
55
- cache_dir=cache_dir)
56
-
57
- # Connexion DuckDB en mémoire
58
- con = duckdb.connect()
59
-
60
- # Créer des tables DuckDB directement à partir des fichiers Parquet
61
- con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
62
- con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
63
- con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
64
-
65
- # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
66
- hf_faiss_index = hf_hub_download(
67
- repo_id=FAISS_REPO_ID,
68
- filename=FAISS_INDEX_FILE,
69
- repo_type="dataset",
70
- token=HF_TOKEN,
71
- cache_dir=cache_dir
72
- )
73
-
74
- # Chargement de l’index FAISS
75
- faiss_index = faiss.read_index(hf_faiss_index)
76
-
77
- # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
78
- dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
79
- arrow_table = dataset.data
80
-
81
- # Creation du Sentence transformer model
82
- device = "cuda" if torch.cuda.is_available() else "cpu"
83
- print(f"*** Device: {device}")
84
- model = SentenceTransformer(MODEL_NAME, device=device)
85
-
86
- # Création du cross-encoder
87
- cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
88
- revision=revision,
89
- trust_remote_code=True)
90
-
91
-
92
- # Fonctions d'accès aux données
93
-
94
- def fetch_tags() -> List[str]:
95
- """
96
- Récupère la liste de tous les tags disponibles dans la base de données.
97
-
98
- Returns:
99
- Dict: Un dictionnaire contenant le statut et les résultats.
100
- - Si succès :
101
- {
102
- "status": "ok",
103
- "result": List[str] # Liste des noms de tags triés par ordre alphabétique
104
- }
105
- - En cas d'erreur :
106
- {
107
- "status": "error",
108
- "code": str, # Nom de l'exception
109
- "message": str # Message de l'exception
110
- }
111
- """
112
- try:
113
- query = "SELECT tag_name FROM tags ORDER BY tag_name"
114
- result = con.execute(query).fetchall()
115
- return {"status": "ok", "result": [row[0] for row in result]}
116
- except Exception as e:
117
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
118
-
119
- def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
120
- """
121
- Récupère les articles associés à un ou plusieurs tags.
122
-
123
- Args:
124
- tags (List[str]): Une liste de noms de tags pour filtrer les articles.
125
-
126
- Returns:
127
- Dict: Un dictionnaire contenant le statut et les résultats.
128
- - Si succès :
129
- {
130
- "status": "ok",
131
- "result": List[Dict] # Liste de dictionnaires représentant les articles
132
- }
133
- Chaque dictionnaire contient les clés :
134
- - 'article_id': int, ID de l'article
135
- - 'article_title': str, Titre de l'article
136
- - 'article_url': str, URL de l'article
137
- - En cas d'erreur ou si aucun tag fourni :
138
- {
139
- "status": "error",
140
- "code": str, # Code d'erreur ou nom de l'exception
141
- "message": str # Message d'erreur
142
- }
143
-
144
- Notes:
145
- - Si la liste `tags` est vide, la fonction retourne une liste vide.
146
- - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
147
- """
148
- if not tags:
149
- return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
150
-
151
- try:
152
- placeholders = ",".join(["?"] * len(tags))
153
- query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
154
- CASE WHEN a.article_online
155
- THEN a.article_url
156
- ELSE 'Article unavailable' END AS url,
157
- FROM tags t, tag_article ta, articles a
158
- WHERE t.tag_id = ta.tag_id
159
- AND ta.article_id = a.article_id
160
- AND t.tag_name IN ({placeholders})
161
- """
162
- result = con.execute(query, tags).fetchdf()
163
- return {"status": "ok", "result": result.to_dict(orient="records")}
164
- except Exception as e:
165
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
166
-
167
- def fetch_query_results(query: str, k_model: int = 10,
168
- k_cross: int = 5, use_rerank: bool = True
169
- ) -> Dict[str, Any]:
170
- """
171
- Exécute une requête de recherche sémantique avec FAISS, puis (optionnellement)
172
- rerank avec un cross-encoder et retourne les meilleurs passages enrichis avec
173
- des métadonnées provenant de DuckDB.
174
-
175
- Paramètres
176
- ----------
177
- query : str
178
- La requête texte fournie par l'utilisateur.
179
- k_model : int, optionnel (défaut = 10)
180
- Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
181
- k_cross : int, optionnel (défaut = 5)
182
- Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
183
- use_rerank : bool, optionnel (défaut = True)
184
- Si False, on désactive complètement le cross-encoder et le rerank.
185
-
186
- Retour
187
- ------
188
- Dict[str, Any]
189
- Un dictionnaire contenant :
190
- - status : "ok" si succès, sinon "error"
191
- - result : liste de résultats (si succès)
192
- - code et message : informations d'erreur (si échec)
193
- """
194
- if not query:
195
- return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
196
- try:
197
- query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
198
- distances, indices = faiss_index.search(query_vec, k_model)
199
-
200
- # Résultats FAISS
201
- faiss_ids_list = indices[0].tolist()
202
- distances_list = distances[0].tolist()
203
-
204
- # Filtrer Arrow sur les IDs trouvés
205
- filtered_table = arrow_table.filter(
206
- pc.is_in(arrow_table['faiss_id'],
207
- value_set=pa.array(faiss_ids_list))
208
- )
209
-
210
- # Convertir Arrow → pandas pour ajouter la distance
211
- df = filtered_table.to_pandas()
212
-
213
- # Ajouter la distance en gardant l'ordre faiss_ids_list
214
- distance_map = dict(zip(faiss_ids_list, distances_list))
215
- df["distance"] = df["faiss_id"].map(distance_map)
216
-
217
- if use_rerank:
218
- status_dbg = "ok_rerank"
219
- # Cross-encoder
220
- df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip()
221
- top_passages = df["chunk_text"].tolist()
222
- cross_input = [(query, p) for p in top_passages]
223
- cross_scores = cross_encoder.predict(cross_input)
224
-
225
- # Rerank
226
- df["cross_score"] = cross_scores
227
- df = df.sort_values(by="cross_score", ascending=False)
228
-
229
- # Garder top k_cross
230
- df_top = df.head(k_cross)
231
- else:
232
- status_dbg = "ok_no_rerank"
233
- df = df.sort_values(by="distance", ascending=False)
234
- df["cross_score"] = df["distance"]
235
- # Garder top k_model
236
- df_top = df.head(k_model)
237
-
238
- # Enregistrer dans DuckDB
239
- con.register("faiss_tmp", df_top)
240
-
241
- sql = """
242
- SELECT
243
- f.faiss_id,
244
- f.document_id,
245
- f.distance,
246
- f.cross_score,
247
- f.chunk_text,
248
- a.article_title,
249
- a.article_url,
250
- CASE WHEN a.article_online
251
- THEN a.article_url
252
- ELSE 'Article unavailable' END AS url,
253
- STRING_AGG(t.tag_name, ', ') AS tags
254
- FROM faiss_tmp f
255
- JOIN articles a ON f.document_id = a.article_id
256
- JOIN tag_article ta ON a.article_id = ta.article_id
257
- JOIN tags t ON ta.tag_id = t.tag_id
258
- WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
259
- GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
260
- a.article_title, a.article_online, a.article_url
261
- ORDER BY AVG(f.cross_score) DESC
262
- """
263
-
264
- duck_res = con.execute(sql).fetchdf()
265
-
266
- # Liste finale de dictionnaires
267
- list_result = duck_res.to_dict(orient="records")
268
-
269
- return {"status": status_dbg, "result": list_result}
270
- except Exception as e:
271
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
272
-