Loren commited on
Commit
0ac5477
·
verified ·
1 Parent(s): 9a517f4

Upload database.py

Browse files
Files changed (1) hide show
  1. app/database.py +256 -252
app/database.py CHANGED
@@ -1,252 +1,256 @@
1
- import os
2
- # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
3
- # définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
4
- cache_dir = "/tmp"
5
- os.makedirs(cache_dir, exist_ok=True)
6
-
7
- # Rediriger le cache HF globalement
8
- os.environ["HF_HOME"] = cache_dir
9
- os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
10
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
11
-
12
- from typing import List, Dict, Any
13
- import duckdb
14
- import faiss
15
- import pandas as pd
16
- from huggingface_hub import hf_hub_download
17
- from sentence_transformers import SentenceTransformer, CrossEncoder
18
- import torch
19
- from datasets import load_dataset
20
- from dotenv import load_dotenv
21
- import pyarrow as pa
22
- import pyarrow.compute as pc
23
-
24
- # Initialisations
25
- load_dotenv()
26
- HF_TOKEN = os.getenv('API_HF_TOKEN')
27
-
28
- REPO_ID = "Loren/articles_database"
29
- FAISS_REPO_ID = "Loren/articles_faiss"
30
- FAISS_INDEX_FILE = "faiss_index.bin"
31
- MODEL_NAME = "intfloat/multilingual-e5-small"
32
- #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
33
- CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base"
34
-
35
- # Téléchargement des fichiers Parquet depuis Hugging Face
36
- articles_parquet = hf_hub_download(
37
- repo_id=REPO_ID,
38
- filename="articles_checked.parquet",
39
- repo_type="dataset",
40
- cache_dir=cache_dir)
41
- tags_parquet = hf_hub_download(
42
- repo_id=REPO_ID,
43
- filename="tags.parquet",
44
- repo_type="dataset",
45
- cache_dir=cache_dir)
46
- tag_article_parquet = hf_hub_download(
47
- repo_id=REPO_ID,
48
- filename="tag_article.parquet",
49
- repo_type="dataset",
50
- cache_dir=cache_dir)
51
-
52
- # Connexion DuckDB en mémoire
53
- con = duckdb.connect()
54
-
55
- # Créer des tables DuckDB directement à partir des fichiers Parquet
56
- con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
57
- con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
58
- con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
59
-
60
- # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
61
- hf_faiss_index = hf_hub_download(
62
- repo_id=FAISS_REPO_ID,
63
- filename=FAISS_INDEX_FILE,
64
- repo_type="dataset",
65
- token=HF_TOKEN,
66
- cache_dir=cache_dir
67
- )
68
-
69
- # Chargement de l’index FAISS
70
- faiss_index = faiss.read_index(hf_faiss_index)
71
-
72
- # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
73
- dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
74
- arrow_table = dataset.data
75
-
76
- # Creation du Sentence transformer model
77
- device = "cuda" if torch.cuda.is_available() else "cpu"
78
- print(f"*** Device: {device}")
79
- model = SentenceTransformer(MODEL_NAME, device=device)
80
-
81
- # Création du cross-encoder
82
- cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
83
- trust_remote_code=True)
84
-
85
-
86
- # Fonctions d'accès aux données
87
-
88
- def fetch_tags() -> List[str]:
89
- """
90
- Récupère la liste de tous les tags disponibles dans la base de données.
91
-
92
- Returns:
93
- Dict: Un dictionnaire contenant le statut et les résultats.
94
- - Si succès :
95
- {
96
- "status": "ok",
97
- "result": List[str] # Liste des noms de tags triés par ordre alphabétique
98
- }
99
- - En cas d'erreur :
100
- {
101
- "status": "error",
102
- "code": str, # Nom de l'exception
103
- "message": str # Message de l'exception
104
- }
105
- """
106
- try:
107
- query = "SELECT tag_name FROM tags ORDER BY tag_name"
108
- result = con.execute(query).fetchall()
109
- return {"status": "ok", "result": [row[0] for row in result]}
110
- except Exception as e:
111
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
112
-
113
- def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
114
- """
115
- Récupère les articles associés à un ou plusieurs tags.
116
-
117
- Args:
118
- tags (List[str]): Une liste de noms de tags pour filtrer les articles.
119
-
120
- Returns:
121
- Dict: Un dictionnaire contenant le statut et les résultats.
122
- - Si succès :
123
- {
124
- "status": "ok",
125
- "result": List[Dict] # Liste de dictionnaires représentant les articles
126
- }
127
- Chaque dictionnaire contient les clés :
128
- - 'article_id': int, ID de l'article
129
- - 'article_title': str, Titre de l'article
130
- - 'article_url': str, URL de l'article
131
- - En cas d'erreur ou si aucun tag fourni :
132
- {
133
- "status": "error",
134
- "code": str, # Code d'erreur ou nom de l'exception
135
- "message": str # Message d'erreur
136
- }
137
-
138
- Notes:
139
- - Si la liste `tags` est vide, la fonction retourne une liste vide.
140
- - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
141
- """
142
- if not tags:
143
- return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
144
-
145
- try:
146
- placeholders = ",".join(["?"] * len(tags))
147
- query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
148
- CASE WHEN a.article_online
149
- THEN a.article_url
150
- ELSE 'Article unavailable' END AS url,
151
- FROM tags t, tag_article ta, articles a
152
- WHERE t.tag_id = ta.tag_id
153
- AND ta.article_id = a.article_id
154
- AND t.tag_name IN ({placeholders})
155
- """
156
- result = con.execute(query, tags).fetchdf()
157
- return {"status": "ok", "result": result.to_dict(orient="records")}
158
- except Exception as e:
159
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
160
-
161
- def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
162
- """
163
- Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
164
- et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
165
-
166
- Paramètres
167
- ----------
168
- query : str
169
- La requête texte fournie par l'utilisateur.
170
- k_model : int, optionnel (défaut = 10)
171
- Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
172
- k_cross : int, optionnel (défaut = 5)
173
- Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
174
-
175
- Retour
176
- ------
177
- Dict[str, Any]
178
- Un dictionnaire contenant :
179
- - status : "ok" si succès, sinon "error"
180
- - result : liste de résultats (si succès)
181
- - code et message : informations d'erreur (si échec)
182
- """
183
- if not query:
184
- return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
185
- try:
186
- query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
187
- distances, indices = faiss_index.search(query_vec, k_model)
188
-
189
- # Résultats FAISS
190
- faiss_ids_list = indices[0].tolist()
191
- distances_list = distances[0].tolist()
192
-
193
- # Filtrer Arrow sur les IDs trouvés
194
- filtered_table = arrow_table.filter(
195
- pc.is_in(arrow_table['faiss_id'],
196
- value_set=pa.array(faiss_ids_list))
197
- )
198
-
199
- # Convertir Arrow → pandas pour ajouter la distance
200
- df = filtered_table.to_pandas()
201
-
202
- # Ajouter la distance en gardant l'ordre faiss_ids_list
203
- distance_map = dict(zip(faiss_ids_list, distances_list))
204
- df["distance"] = df["faiss_id"].map(distance_map)
205
-
206
- # Cross-encoder
207
- top_passages = df["chunk_text"].tolist()
208
- cross_input = [(query, p) for p in top_passages]
209
- cross_scores = cross_encoder.predict(cross_input)
210
-
211
- # Rerank
212
- df["cross_score"] = cross_scores
213
- df = df.sort_values(by="cross_score", ascending=False)
214
-
215
- # Garder top k_cross
216
- df_top = df.head(k_cross)
217
-
218
- # Enregistrer dans DuckDB
219
- con.register("faiss_tmp", df_top)
220
-
221
- sql = """
222
- SELECT
223
- f.faiss_id,
224
- f.document_id,
225
- f.distance,
226
- f.cross_score,
227
- f.chunk_text,
228
- a.article_title,
229
- a.article_url,
230
- CASE WHEN a.article_online
231
- THEN a.article_url
232
- ELSE 'Article unavailable' END AS url,
233
- STRING_AGG(t.tag_name, ', ') AS tags
234
- FROM faiss_tmp f
235
- JOIN articles a ON f.document_id = a.article_id
236
- JOIN tag_article ta ON a.article_id = ta.article_id
237
- JOIN tags t ON ta.tag_id = t.tag_id
238
- WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
239
- GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
240
- a.article_title, a.article_online, a.article_url
241
- ORDER BY AVG(f.cross_score) DESC
242
- """
243
-
244
- duck_res = con.execute(sql).fetchdf()
245
-
246
- # Liste finale de dictionnaires
247
- list_result = duck_res.to_dict(orient="records")
248
-
249
- return {"status": "ok", "result": list_result}
250
- except Exception as e:
251
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
252
-
 
 
 
 
 
1
+ import os
2
+ # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
3
+ # définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
4
+ cache_dir = "/tmp"
5
+ os.makedirs(cache_dir, exist_ok=True)
6
+
7
+ # Rediriger le cache HF globalement
8
+ os.environ["HF_HOME"] = cache_dir
9
+ os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
10
+ os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
11
+
12
+ from typing import List, Dict, Any
13
+ import duckdb
14
+ import faiss
15
+ import pandas as pd
16
+ from huggingface_hub import hf_hub_download
17
+ from sentence_transformers import SentenceTransformer, CrossEncoder
18
+ import torch
19
+ from datasets import load_dataset
20
+ from dotenv import load_dotenv
21
+ import pyarrow as pa
22
+ import pyarrow.compute as pc
23
+
24
+ # Initialisations
25
+ load_dotenv()
26
+ HF_TOKEN = os.getenv('API_HF_TOKEN')
27
+
28
+ REPO_ID = "Loren/articles_database"
29
+ FAISS_REPO_ID = "Loren/articles_faiss"
30
+ FAISS_INDEX_FILE = "faiss_index.bin"
31
+ MODEL_NAME = "intfloat/multilingual-e5-small"
32
+ #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
33
+ CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base"
34
+ revision = "a6258e9d2b1a11aa7bccdff9efde562bbca4393d"
35
+ #ROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual"
36
+
37
+ # Téléchargement des fichiers Parquet depuis Hugging Face
38
+ articles_parquet = hf_hub_download(
39
+ repo_id=REPO_ID,
40
+ filename="articles_checked.parquet",
41
+ repo_type="dataset",
42
+ cache_dir=cache_dir)
43
+ tags_parquet = hf_hub_download(
44
+ repo_id=REPO_ID,
45
+ filename="tags.parquet",
46
+ repo_type="dataset",
47
+ cache_dir=cache_dir)
48
+ tag_article_parquet = hf_hub_download(
49
+ repo_id=REPO_ID,
50
+ filename="tag_article.parquet",
51
+ repo_type="dataset",
52
+ cache_dir=cache_dir)
53
+
54
+ # Connexion DuckDB en mémoire
55
+ con = duckdb.connect()
56
+
57
+ # Créer des tables DuckDB directement à partir des fichiers Parquet
58
+ con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
59
+ con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
60
+ con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
61
+
62
+ # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
63
+ hf_faiss_index = hf_hub_download(
64
+ repo_id=FAISS_REPO_ID,
65
+ filename=FAISS_INDEX_FILE,
66
+ repo_type="dataset",
67
+ token=HF_TOKEN,
68
+ cache_dir=cache_dir
69
+ )
70
+
71
+ # Chargement de l’index FAISS
72
+ faiss_index = faiss.read_index(hf_faiss_index)
73
+
74
+ # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
75
+ dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
76
+ arrow_table = dataset.data
77
+
78
+ # Creation du Sentence transformer model
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ print(f"*** Device: {device}")
81
+ model = SentenceTransformer(MODEL_NAME, device=device)
82
+
83
+ # Création du cross-encoder
84
+ cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
85
+ revision=revision,
86
+ trust_remote_code=True)
87
+
88
+
89
+ # Fonctions d'accès aux données
90
+
91
+ def fetch_tags() -> List[str]:
92
+ """
93
+ Récupère la liste de tous les tags disponibles dans la base de données.
94
+
95
+ Returns:
96
+ Dict: Un dictionnaire contenant le statut et les résultats.
97
+ - Si succès :
98
+ {
99
+ "status": "ok",
100
+ "result": List[str] # Liste des noms de tags triés par ordre alphabétique
101
+ }
102
+ - En cas d'erreur :
103
+ {
104
+ "status": "error",
105
+ "code": str, # Nom de l'exception
106
+ "message": str # Message de l'exception
107
+ }
108
+ """
109
+ try:
110
+ query = "SELECT tag_name FROM tags ORDER BY tag_name"
111
+ result = con.execute(query).fetchall()
112
+ return {"status": "ok", "result": [row[0] for row in result]}
113
+ except Exception as e:
114
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
115
+
116
+ def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
117
+ """
118
+ Récupère les articles associés à un ou plusieurs tags.
119
+
120
+ Args:
121
+ tags (List[str]): Une liste de noms de tags pour filtrer les articles.
122
+
123
+ Returns:
124
+ Dict: Un dictionnaire contenant le statut et les résultats.
125
+ - Si succès :
126
+ {
127
+ "status": "ok",
128
+ "result": List[Dict] # Liste de dictionnaires représentant les articles
129
+ }
130
+ Chaque dictionnaire contient les clés :
131
+ - 'article_id': int, ID de l'article
132
+ - 'article_title': str, Titre de l'article
133
+ - 'article_url': str, URL de l'article
134
+ - En cas d'erreur ou si aucun tag fourni :
135
+ {
136
+ "status": "error",
137
+ "code": str, # Code d'erreur ou nom de l'exception
138
+ "message": str # Message d'erreur
139
+ }
140
+
141
+ Notes:
142
+ - Si la liste `tags` est vide, la fonction retourne une liste vide.
143
+ - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
144
+ """
145
+ if not tags:
146
+ return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
147
+
148
+ try:
149
+ placeholders = ",".join(["?"] * len(tags))
150
+ query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
151
+ CASE WHEN a.article_online
152
+ THEN a.article_url
153
+ ELSE 'Article unavailable' END AS url,
154
+ FROM tags t, tag_article ta, articles a
155
+ WHERE t.tag_id = ta.tag_id
156
+ AND ta.article_id = a.article_id
157
+ AND t.tag_name IN ({placeholders})
158
+ """
159
+ result = con.execute(query, tags).fetchdf()
160
+ return {"status": "ok", "result": result.to_dict(orient="records")}
161
+ except Exception as e:
162
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
163
+
164
+ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
165
+ """
166
+ Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
167
+ et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
168
+
169
+ Paramètres
170
+ ----------
171
+ query : str
172
+ La requête texte fournie par l'utilisateur.
173
+ k_model : int, optionnel (défaut = 10)
174
+ Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
175
+ k_cross : int, optionnel (défaut = 5)
176
+ Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
177
+
178
+ Retour
179
+ ------
180
+ Dict[str, Any]
181
+ Un dictionnaire contenant :
182
+ - status : "ok" si succès, sinon "error"
183
+ - result : liste de résultats (si succès)
184
+ - code et message : informations d'erreur (si échec)
185
+ """
186
+ if not query:
187
+ return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
188
+ try:
189
+ query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
190
+ distances, indices = faiss_index.search(query_vec, k_model)
191
+
192
+ # Résultats FAISS
193
+ faiss_ids_list = indices[0].tolist()
194
+ distances_list = distances[0].tolist()
195
+
196
+ # Filtrer Arrow sur les IDs trouvés
197
+ filtered_table = arrow_table.filter(
198
+ pc.is_in(arrow_table['faiss_id'],
199
+ value_set=pa.array(faiss_ids_list))
200
+ )
201
+
202
+ # Convertir Arrow pandas pour ajouter la distance
203
+ df = filtered_table.to_pandas()
204
+
205
+ # Ajouter la distance en gardant l'ordre faiss_ids_list
206
+ distance_map = dict(zip(faiss_ids_list, distances_list))
207
+ df["distance"] = df["faiss_id"].map(distance_map)
208
+
209
+ # Cross-encoder
210
+ df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip()
211
+ top_passages = df["chunk_text"].tolist()
212
+ cross_input = [(query, p) for p in top_passages]
213
+ cross_scores = cross_encoder.predict(cross_input)
214
+
215
+ # Rerank
216
+ df["cross_score"] = cross_scores
217
+ df = df.sort_values(by="cross_score", ascending=False)
218
+
219
+ # Garder top k_cross
220
+ df_top = df.head(k_cross)
221
+
222
+ # Enregistrer dans DuckDB
223
+ con.register("faiss_tmp", df_top)
224
+
225
+ sql = """
226
+ SELECT
227
+ f.faiss_id,
228
+ f.document_id,
229
+ f.distance,
230
+ f.cross_score,
231
+ f.chunk_text,
232
+ a.article_title,
233
+ a.article_url,
234
+ CASE WHEN a.article_online
235
+ THEN a.article_url
236
+ ELSE 'Article unavailable' END AS url,
237
+ STRING_AGG(t.tag_name, ', ') AS tags
238
+ FROM faiss_tmp f
239
+ JOIN articles a ON f.document_id = a.article_id
240
+ JOIN tag_article ta ON a.article_id = ta.article_id
241
+ JOIN tags t ON ta.tag_id = t.tag_id
242
+ WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
243
+ GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
244
+ a.article_title, a.article_online, a.article_url
245
+ ORDER BY AVG(f.cross_score) DESC
246
+ """
247
+
248
+ duck_res = con.execute(sql).fetchdf()
249
+
250
+ # Liste finale de dictionnaires
251
+ list_result = duck_res.to_dict(orient="records")
252
+
253
+ return {"status": "ok", "result": list_result}
254
+ except Exception as e:
255
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
256
+