Loren commited on
Commit
c4e2e17
·
verified ·
1 Parent(s): 565d849

Update app/database.py

Browse files
Files changed (1) hide show
  1. app/database.py +252 -251
app/database.py CHANGED
@@ -1,251 +1,252 @@
1
- import os
2
- # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
3
- # définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
4
- cache_dir = "/tmp"
5
- os.makedirs(cache_dir, exist_ok=True)
6
-
7
- # Rediriger le cache HF globalement
8
- os.environ["HF_HOME"] = cache_dir
9
- os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
10
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
11
-
12
- from typing import List, Dict, Any
13
- import duckdb
14
- import faiss
15
- import pandas as pd
16
- from huggingface_hub import hf_hub_download
17
- from sentence_transformers import SentenceTransformer, CrossEncoder
18
- import torch
19
- from datasets import load_dataset
20
- from dotenv import load_dotenv
21
- import pyarrow as pa
22
- import pyarrow.compute as pc
23
-
24
- # Initialisations
25
- load_dotenv()
26
- HF_TOKEN = os.getenv('API_HF_TOKEN')
27
-
28
- REPO_ID = "Loren/articles_database"
29
- FAISS_REPO_ID = "Loren/articles_faiss"
30
- FAISS_INDEX_FILE = "faiss_index.bin"
31
- MODEL_NAME = "intfloat/multilingual-e5-small"
32
- CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
33
-
34
- # Téléchargement des fichiers Parquet depuis Hugging Face
35
- articles_parquet = hf_hub_download(
36
- repo_id=REPO_ID,
37
- filename="articles_checked.parquet",
38
- repo_type="dataset",
39
- cache_dir=cache_dir)
40
- tags_parquet = hf_hub_download(
41
- repo_id=REPO_ID,
42
- filename="tags.parquet",
43
- repo_type="dataset",
44
- cache_dir=cache_dir)
45
- tag_article_parquet = hf_hub_download(
46
- repo_id=REPO_ID,
47
- filename="tag_article.parquet",
48
- repo_type="dataset",
49
- cache_dir=cache_dir)
50
-
51
- # Connexion DuckDB en mémoire
52
- con = duckdb.connect()
53
-
54
- # Créer des tables DuckDB directement à partir des fichiers Parquet
55
- con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
56
- con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
57
- con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
58
-
59
- # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
60
- hf_faiss_index = hf_hub_download(
61
- repo_id=FAISS_REPO_ID,
62
- filename=FAISS_INDEX_FILE,
63
- repo_type="dataset",
64
- token=HF_TOKEN,
65
- cache_dir=cache_dir
66
- )
67
-
68
- # Chargement de l’index FAISS
69
- faiss_index = faiss.read_index(hf_faiss_index)
70
-
71
- # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
72
- dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
73
- arrow_table = dataset.data
74
-
75
- # Creation du Sentence transformer model
76
- device = "cuda" if torch.cuda.is_available() else "cpu"
77
- print(f"*** Device: {device}")
78
- model = SentenceTransformer(MODEL_NAME, device=device)
79
-
80
- # Création du cross-encoder
81
- cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
82
- trust_remote_code=True)
83
-
84
-
85
- # Fonctions d'accès aux données
86
-
87
- def fetch_tags() -> List[str]:
88
- """
89
- Récupère la liste de tous les tags disponibles dans la base de données.
90
-
91
- Returns:
92
- Dict: Un dictionnaire contenant le statut et les résultats.
93
- - Si succès :
94
- {
95
- "status": "ok",
96
- "result": List[str] # Liste des noms de tags triés par ordre alphabétique
97
- }
98
- - En cas d'erreur :
99
- {
100
- "status": "error",
101
- "code": str, # Nom de l'exception
102
- "message": str # Message de l'exception
103
- }
104
- """
105
- try:
106
- query = "SELECT tag_name FROM tags ORDER BY tag_name"
107
- result = con.execute(query).fetchall()
108
- return {"status": "ok", "result": [row[0] for row in result]}
109
- except Exception as e:
110
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
111
-
112
- def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
113
- """
114
- Récupère les articles associés à un ou plusieurs tags.
115
-
116
- Args:
117
- tags (List[str]): Une liste de noms de tags pour filtrer les articles.
118
-
119
- Returns:
120
- Dict: Un dictionnaire contenant le statut et les résultats.
121
- - Si succès :
122
- {
123
- "status": "ok",
124
- "result": List[Dict] # Liste de dictionnaires représentant les articles
125
- }
126
- Chaque dictionnaire contient les clés :
127
- - 'article_id': int, ID de l'article
128
- - 'article_title': str, Titre de l'article
129
- - 'article_url': str, URL de l'article
130
- - En cas d'erreur ou si aucun tag fourni :
131
- {
132
- "status": "error",
133
- "code": str, # Code d'erreur ou nom de l'exception
134
- "message": str # Message d'erreur
135
- }
136
-
137
- Notes:
138
- - Si la liste `tags` est vide, la fonction retourne une liste vide.
139
- - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
140
- """
141
- if not tags:
142
- return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
143
-
144
- try:
145
- placeholders = ",".join(["?"] * len(tags))
146
- query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
147
- CASE WHEN a.article_online
148
- THEN a.article_url
149
- ELSE 'Article unavailable' END AS url,
150
- FROM tags t, tag_article ta, articles a
151
- WHERE t.tag_id = ta.tag_id
152
- AND ta.article_id = a.article_id
153
- AND t.tag_name IN ({placeholders})
154
- """
155
- result = con.execute(query, tags).fetchdf()
156
- return {"status": "ok", "result": result.to_dict(orient="records")}
157
- except Exception as e:
158
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
159
-
160
- def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
161
- """
162
- Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
163
- et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
164
-
165
- Paramètres
166
- ----------
167
- query : str
168
- La requête texte fournie par l'utilisateur.
169
- k_model : int, optionnel (défaut = 10)
170
- Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
171
- k_cross : int, optionnel (défaut = 5)
172
- Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
173
-
174
- Retour
175
- ------
176
- Dict[str, Any]
177
- Un dictionnaire contenant :
178
- - status : "ok" si succès, sinon "error"
179
- - result : liste de résultats (si succès)
180
- - code et message : informations d'erreur (si échec)
181
- """
182
- if not query:
183
- return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
184
- try:
185
- query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
186
- distances, indices = faiss_index.search(query_vec, k_model)
187
-
188
- # Résultats FAISS
189
- faiss_ids_list = indices[0].tolist()
190
- distances_list = distances[0].tolist()
191
-
192
- # Filtrer Arrow sur les IDs trouvés
193
- filtered_table = arrow_table.filter(
194
- pc.is_in(arrow_table['faiss_id'],
195
- value_set=pa.array(faiss_ids_list))
196
- )
197
-
198
- # Convertir Arrow → pandas pour ajouter la distance
199
- df = filtered_table.to_pandas()
200
-
201
- # Ajouter la distance en gardant l'ordre faiss_ids_list
202
- distance_map = dict(zip(faiss_ids_list, distances_list))
203
- df["distance"] = df["faiss_id"].map(distance_map)
204
-
205
- # Cross-encoder
206
- top_passages = df["chunk_text"].tolist()
207
- cross_input = [(query, p) for p in top_passages]
208
- cross_scores = cross_encoder.predict(cross_input)
209
-
210
- # Rerank
211
- df["cross_score"] = cross_scores
212
- df = df.sort_values(by="cross_score", ascending=False)
213
-
214
- # Garder top k_cross
215
- df_top = df.head(k_cross)
216
-
217
- # Enregistrer dans DuckDB
218
- con.register("faiss_tmp", df_top)
219
-
220
- sql = """
221
- SELECT
222
- f.faiss_id,
223
- f.document_id,
224
- f.distance,
225
- f.cross_score,
226
- f.chunk_text,
227
- a.article_title,
228
- a.article_url,
229
- CASE WHEN a.article_online
230
- THEN a.article_url
231
- ELSE 'Article unavailable' END AS url,
232
- STRING_AGG(t.tag_name, ', ') AS tags
233
- FROM faiss_tmp f
234
- JOIN articles a ON f.document_id = a.article_id
235
- JOIN tag_article ta ON a.article_id = ta.article_id
236
- JOIN tags t ON ta.tag_id = t.tag_id
237
- WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
238
- GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
239
- a.article_title, a.article_online, a.article_url
240
- ORDER BY AVG(f.cross_score) DESC
241
- """
242
-
243
- duck_res = con.execute(sql).fetchdf()
244
-
245
- # Liste finale de dictionnaires
246
- list_result = duck_res.to_dict(orient="records")
247
-
248
- return {"status": "ok", "result": list_result}
249
- except Exception as e:
250
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
251
-
 
 
1
+ import os
2
+ # Règle d’or : toute variable d’environnement qui influence le cache Hugging Face doit être
3
+ # définie avant d’importer datasets ou transformers, sinon elle sera ignorée.
4
+ cache_dir = "/tmp"
5
+ os.makedirs(cache_dir, exist_ok=True)
6
+
7
+ # Rediriger le cache HF globalement
8
+ os.environ["HF_HOME"] = cache_dir
9
+ os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_dir, "datasets")
10
+ os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_dir, "transformers")
11
+
12
+ from typing import List, Dict, Any
13
+ import duckdb
14
+ import faiss
15
+ import pandas as pd
16
+ from huggingface_hub import hf_hub_download
17
+ from sentence_transformers import SentenceTransformer, CrossEncoder
18
+ import torch
19
+ from datasets import load_dataset
20
+ from dotenv import load_dotenv
21
+ import pyarrow as pa
22
+ import pyarrow.compute as pc
23
+
24
+ # Initialisations
25
+ load_dotenv()
26
+ HF_TOKEN = os.getenv('API_HF_TOKEN')
27
+
28
+ REPO_ID = "Loren/articles_database"
29
+ FAISS_REPO_ID = "Loren/articles_faiss"
30
+ FAISS_INDEX_FILE = "faiss_index.bin"
31
+ MODEL_NAME = "intfloat/multilingual-e5-small"
32
+ #CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
33
+ CROSS_ENCODER_NAME = "Alibaba-NLP/gte-multilingual-reranker-base"
34
+
35
+ # Téléchargement des fichiers Parquet depuis Hugging Face
36
+ articles_parquet = hf_hub_download(
37
+ repo_id=REPO_ID,
38
+ filename="articles_checked.parquet",
39
+ repo_type="dataset",
40
+ cache_dir=cache_dir)
41
+ tags_parquet = hf_hub_download(
42
+ repo_id=REPO_ID,
43
+ filename="tags.parquet",
44
+ repo_type="dataset",
45
+ cache_dir=cache_dir)
46
+ tag_article_parquet = hf_hub_download(
47
+ repo_id=REPO_ID,
48
+ filename="tag_article.parquet",
49
+ repo_type="dataset",
50
+ cache_dir=cache_dir)
51
+
52
+ # Connexion DuckDB en mémoire
53
+ con = duckdb.connect()
54
+
55
+ # Créer des tables DuckDB directement à partir des fichiers Parquet
56
+ con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
57
+ con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
58
+ con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
59
+
60
+ # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
61
+ hf_faiss_index = hf_hub_download(
62
+ repo_id=FAISS_REPO_ID,
63
+ filename=FAISS_INDEX_FILE,
64
+ repo_type="dataset",
65
+ token=HF_TOKEN,
66
+ cache_dir=cache_dir
67
+ )
68
+
69
+ # Chargement de l’index FAISS
70
+ faiss_index = faiss.read_index(hf_faiss_index)
71
+
72
+ # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
73
+ dataset = load_dataset(FAISS_REPO_ID, split="train", token=HF_TOKEN)
74
+ arrow_table = dataset.data
75
+
76
+ # Creation du Sentence transformer model
77
+ device = "cuda" if torch.cuda.is_available() else "cpu"
78
+ print(f"*** Device: {device}")
79
+ model = SentenceTransformer(MODEL_NAME, device=device)
80
+
81
+ # Création du cross-encoder
82
+ cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
83
+ trust_remote_code=True)
84
+
85
+
86
+ # Fonctions d'accès aux données
87
+
88
+ def fetch_tags() -> List[str]:
89
+ """
90
+ Récupère la liste de tous les tags disponibles dans la base de données.
91
+
92
+ Returns:
93
+ Dict: Un dictionnaire contenant le statut et les résultats.
94
+ - Si succès :
95
+ {
96
+ "status": "ok",
97
+ "result": List[str] # Liste des noms de tags triés par ordre alphabétique
98
+ }
99
+ - En cas d'erreur :
100
+ {
101
+ "status": "error",
102
+ "code": str, # Nom de l'exception
103
+ "message": str # Message de l'exception
104
+ }
105
+ """
106
+ try:
107
+ query = "SELECT tag_name FROM tags ORDER BY tag_name"
108
+ result = con.execute(query).fetchall()
109
+ return {"status": "ok", "result": [row[0] for row in result]}
110
+ except Exception as e:
111
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
112
+
113
+ def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
114
+ """
115
+ Récupère les articles associés à un ou plusieurs tags.
116
+
117
+ Args:
118
+ tags (List[str]): Une liste de noms de tags pour filtrer les articles.
119
+
120
+ Returns:
121
+ Dict: Un dictionnaire contenant le statut et les résultats.
122
+ - Si succès :
123
+ {
124
+ "status": "ok",
125
+ "result": List[Dict] # Liste de dictionnaires représentant les articles
126
+ }
127
+ Chaque dictionnaire contient les clés :
128
+ - 'article_id': int, ID de l'article
129
+ - 'article_title': str, Titre de l'article
130
+ - 'article_url': str, URL de l'article
131
+ - En cas d'erreur ou si aucun tag fourni :
132
+ {
133
+ "status": "error",
134
+ "code": str, # Code d'erreur ou nom de l'exception
135
+ "message": str # Message d'erreur
136
+ }
137
+
138
+ Notes:
139
+ - Si la liste `tags` est vide, la fonction retourne une liste vide.
140
+ - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
141
+ """
142
+ if not tags:
143
+ return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
144
+
145
+ try:
146
+ placeholders = ",".join(["?"] * len(tags))
147
+ query = f"""SELECT distinct a.article_id, a.article_title, a.article_url,
148
+ CASE WHEN a.article_online
149
+ THEN a.article_url
150
+ ELSE 'Article unavailable' END AS url,
151
+ FROM tags t, tag_article ta, articles a
152
+ WHERE t.tag_id = ta.tag_id
153
+ AND ta.article_id = a.article_id
154
+ AND t.tag_name IN ({placeholders})
155
+ """
156
+ result = con.execute(query, tags).fetchdf()
157
+ return {"status": "ok", "result": result.to_dict(orient="records")}
158
+ except Exception as e:
159
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
160
+
161
+ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
162
+ """
163
+ Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
164
+ et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
165
+
166
+ Paramètres
167
+ ----------
168
+ query : str
169
+ La requête texte fournie par l'utilisateur.
170
+ k_model : int, optionnel (défaut = 10)
171
+ Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
172
+ k_cross : int, optionnel (défaut = 5)
173
+ Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
174
+
175
+ Retour
176
+ ------
177
+ Dict[str, Any]
178
+ Un dictionnaire contenant :
179
+ - status : "ok" si succès, sinon "error"
180
+ - result : liste de résultats (si succès)
181
+ - code et message : informations d'erreur (si échec)
182
+ """
183
+ if not query:
184
+ return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
185
+ try:
186
+ query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
187
+ distances, indices = faiss_index.search(query_vec, k_model)
188
+
189
+ # Résultats FAISS
190
+ faiss_ids_list = indices[0].tolist()
191
+ distances_list = distances[0].tolist()
192
+
193
+ # Filtrer Arrow sur les IDs trouvés
194
+ filtered_table = arrow_table.filter(
195
+ pc.is_in(arrow_table['faiss_id'],
196
+ value_set=pa.array(faiss_ids_list))
197
+ )
198
+
199
+ # Convertir Arrow → pandas pour ajouter la distance
200
+ df = filtered_table.to_pandas()
201
+
202
+ # Ajouter la distance en gardant l'ordre faiss_ids_list
203
+ distance_map = dict(zip(faiss_ids_list, distances_list))
204
+ df["distance"] = df["faiss_id"].map(distance_map)
205
+
206
+ # Cross-encoder
207
+ top_passages = df["chunk_text"].tolist()
208
+ cross_input = [(query, p) for p in top_passages]
209
+ cross_scores = cross_encoder.predict(cross_input)
210
+
211
+ # Rerank
212
+ df["cross_score"] = cross_scores
213
+ df = df.sort_values(by="cross_score", ascending=False)
214
+
215
+ # Garder top k_cross
216
+ df_top = df.head(k_cross)
217
+
218
+ # Enregistrer dans DuckDB
219
+ con.register("faiss_tmp", df_top)
220
+
221
+ sql = """
222
+ SELECT
223
+ f.faiss_id,
224
+ f.document_id,
225
+ f.distance,
226
+ f.cross_score,
227
+ f.chunk_text,
228
+ a.article_title,
229
+ a.article_url,
230
+ CASE WHEN a.article_online
231
+ THEN a.article_url
232
+ ELSE 'Article unavailable' END AS url,
233
+ STRING_AGG(t.tag_name, ', ') AS tags
234
+ FROM faiss_tmp f
235
+ JOIN articles a ON f.document_id = a.article_id
236
+ JOIN tag_article ta ON a.article_id = ta.article_id
237
+ JOIN tags t ON ta.tag_id = t.tag_id
238
+ WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
239
+ GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
240
+ a.article_title, a.article_online, a.article_url
241
+ ORDER BY AVG(f.cross_score) DESC
242
+ """
243
+
244
+ duck_res = con.execute(sql).fetchdf()
245
+
246
+ # Liste finale de dictionnaires
247
+ list_result = duck_res.to_dict(orient="records")
248
+
249
+ return {"status": "ok", "result": list_result}
250
+ except Exception as e:
251
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
252
+