Loren commited on
Commit
2117ac7
·
verified ·
1 Parent(s): d8b8ff8

Update app/database.py

Browse files
Files changed (1) hide show
  1. app/database.py +243 -243
app/database.py CHANGED
@@ -1,243 +1,243 @@
1
- import os
2
- from typing import List, Dict, Any
3
- import duckdb
4
- import faiss
5
- import pandas as pd
6
- from huggingface_hub import hf_hub_download
7
- from sentence_transformers import SentenceTransformer, CrossEncoder
8
- import torch
9
- from datasets import load_dataset
10
- from dotenv import load_dotenv
11
- import pyarrow as pa
12
- import pyarrow.compute as pc
13
-
14
- # Initialisations
15
- load_dotenv()
16
- HF_TOKEN = os.getenv('HF_TOKEN')
17
-
18
- REPO_ID = "Loren/articles_database"
19
- FAISS_REPO_ID = "Loren/articles_faiss"
20
- FAISS_INDEX_FILE = "faiss_index.bin"
21
- MODEL_NAME = "intfloat/multilingual-e5-small"
22
- CROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual"
23
-
24
- cache_dir = "/tmp"
25
- os.makedirs(cache_dir, exist_ok=True)
26
- # Rediriger le cache HF globalement
27
- os.environ["HF_HOME"] = cache_dir
28
- os.environ["HF_DATASETS_CACHE"] = cache_dir
29
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
30
-
31
- # Téléchargement des fichiers Parquet depuis Hugging Face
32
- articles_parquet = hf_hub_download(
33
- repo_id=REPO_ID,
34
- filename="articles.parquet",
35
- repo_type="dataset",
36
- cache_dir=cache_dir)
37
- tags_parquet = hf_hub_download(
38
- repo_id=REPO_ID,
39
- filename="tags.parquet",
40
- repo_type="dataset",
41
- cache_dir=cache_dir)
42
- tag_article_parquet = hf_hub_download(
43
- repo_id=REPO_ID,
44
- filename="tag_article.parquet",
45
- repo_type="dataset",
46
- cache_dir=cache_dir)
47
-
48
- # Connexion DuckDB en mémoire
49
- con = duckdb.connect()
50
-
51
- # Créer des tables DuckDB directement à partir des fichiers Parquet
52
- con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
53
- con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
54
- con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
55
-
56
- # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
57
- hf_faiss_index = hf_hub_download(
58
- repo_id=FAISS_REPO_ID,
59
- filename=FAISS_INDEX_FILE,
60
- repo_type="dataset",
61
- token=HF_TOKEN,
62
- cache_dir=cache_dir
63
- )
64
-
65
- # Chargement de l’index FAISS
66
- faiss_index = faiss.read_index(hf_faiss_index)
67
-
68
- # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
69
- dataset = load_dataset(FAISS_REPO_ID, split="train")
70
- arrow_table = dataset.data
71
-
72
- # Creation du Sentence transformer model
73
- device = "cuda" if torch.cuda.is_available() else "cpu"
74
- print(f"*** Device: {device}")
75
- model = SentenceTransformer(MODEL_NAME, device=device)
76
-
77
- # Création du cross-encoder
78
- cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
79
- trust_remote_code=True)
80
-
81
-
82
- # Fonctions d'accès aux données
83
-
84
- def fetch_tags() -> List[str]:
85
- """
86
- Récupère la liste de tous les tags disponibles dans la base de données.
87
-
88
- Returns:
89
- Dict: Un dictionnaire contenant le statut et les résultats.
90
- - Si succès :
91
- {
92
- "status": "ok",
93
- "result": List[str] # Liste des noms de tags triés par ordre alphabétique
94
- }
95
- - En cas d'erreur :
96
- {
97
- "status": "error",
98
- "code": str, # Nom de l'exception
99
- "message": str # Message de l'exception
100
- }
101
- """
102
- try:
103
- query = "SELECT tag_name FROM tags ORDER BY tag_name"
104
- result = con.execute(query).fetchall()
105
- return {"status": "ok", "result": [row[0] for row in result]}
106
- except Exception as e:
107
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
108
-
109
- def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
110
- """
111
- Récupère les articles associés à un ou plusieurs tags.
112
-
113
- Args:
114
- tags (List[str]): Une liste de noms de tags pour filtrer les articles.
115
-
116
- Returns:
117
- Dict: Un dictionnaire contenant le statut et les résultats.
118
- - Si succès :
119
- {
120
- "status": "ok",
121
- "result": List[Dict] # Liste de dictionnaires représentant les articles
122
- }
123
- Chaque dictionnaire contient les clés :
124
- - 'article_id': int, ID de l'article
125
- - 'article_title': str, Titre de l'article
126
- - 'article_url': str, URL de l'article
127
- - En cas d'erreur ou si aucun tag fourni :
128
- {
129
- "status": "error",
130
- "code": str, # Code d'erreur ou nom de l'exception
131
- "message": str # Message d'erreur
132
- }
133
-
134
- Notes:
135
- - Si la liste `tags` est vide, la fonction retourne une liste vide.
136
- - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
137
- """
138
- if not tags:
139
- return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
140
-
141
- try:
142
- placeholders = ",".join(["?"] * len(tags))
143
- query = f"""SELECT distinct a.article_id, a.article_title, a.article_url
144
- FROM tags t, tag_article ta, articles a
145
- WHERE t.tag_id = ta.tag_id
146
- AND ta.article_id = a.article_id
147
- AND t.tag_name IN ({placeholders})
148
- """
149
- result = con.execute(query, tags).fetchdf()
150
- return {"status": "ok", "result": result.to_dict(orient="records")}
151
- except Exception as e:
152
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
153
-
154
- def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
155
- """
156
- Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
157
- et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
158
-
159
- Paramètres
160
- ----------
161
- query : str
162
- La requête texte fournie par l'utilisateur.
163
- k_model : int, optionnel (défaut = 10)
164
- Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
165
- k_cross : int, optionnel (défaut = 5)
166
- Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
167
-
168
- Retour
169
- ------
170
- Dict[str, Any]
171
- Un dictionnaire contenant :
172
- - status : "ok" si succès, sinon "error"
173
- - result : liste de résultats (si succès)
174
- - code et message : informations d'erreur (si échec)
175
- """
176
- if not query:
177
- return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
178
- try:
179
- query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
180
- distances, indices = faiss_index.search(query_vec, k_model)
181
-
182
- # Résultats FAISS
183
- faiss_ids_list = indices[0].tolist()
184
- distances_list = distances[0].tolist()
185
-
186
- # Filtrer Arrow sur les IDs trouvés
187
- filtered_table = arrow_table.filter(
188
- pc.is_in(arrow_table['faiss_id'],
189
- value_set=pa.array(faiss_ids_list))
190
- )
191
-
192
- # Convertir Arrow → pandas pour ajouter la distance
193
- df = filtered_table.to_pandas()
194
-
195
- # Ajouter la distance en gardant l'ordre faiss_ids_list
196
- distance_map = dict(zip(faiss_ids_list, distances_list))
197
- df["distance"] = df["faiss_id"].map(distance_map)
198
-
199
- # Cross-encoder
200
- top_passages = df["chunk_text"].tolist()
201
- cross_input = [(query, p) for p in top_passages]
202
- cross_scores = cross_encoder.predict(cross_input)
203
-
204
- # Rerank
205
- df["cross_score"] = cross_scores
206
- df = df.sort_values(by="cross_score", ascending=False)
207
-
208
- # Garder top k_cross
209
- df_top = df.head(k_cross)
210
-
211
- # Enregistrer dans DuckDB
212
- con.register("faiss_tmp", df_top)
213
-
214
- sql = """
215
- SELECT
216
- f.faiss_id,
217
- f.document_id,
218
- f.distance,
219
- f.cross_score,
220
- f.chunk_text,
221
- a.article_title,
222
- CASE WHEN a.article_online
223
- THEN a.article_url
224
- ELSE 'Article unavailable' END AS url,
225
- STRING_AGG(t.tag_name, ', ') AS tags
226
- FROM faiss_tmp f
227
- JOIN articles a ON f.document_id = a.article_id
228
- JOIN tag_article ta ON a.article_id = ta.article_id
229
- JOIN tags t ON ta.tag_id = t.tag_id
230
- WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
231
- GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
232
- a.article_title, a.article_online, a.article_url
233
- ORDER BY AVG(f.cross_score)
234
- """
235
-
236
- duck_res = con.execute(sql).fetchdf()
237
-
238
- # Liste finale de dictionnaires
239
- list_result = duck_res.to_dict(orient="records")
240
-
241
- return {"status": "ok", "result": list_result}
242
- except Exception as e:
243
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
 
1
+ import os
2
+ from typing import List, Dict, Any
3
+ import duckdb
4
+ import faiss
5
+ import pandas as pd
6
+ from huggingface_hub import hf_hub_download
7
+ from sentence_transformers import SentenceTransformer, CrossEncoder
8
+ import torch
9
+ from datasets import load_dataset
10
+ from dotenv import load_dotenv
11
+ import pyarrow as pa
12
+ import pyarrow.compute as pc
13
+
14
+ # Initialisations
15
+ load_dotenv()
16
+ HF_TOKEN = os.getenv('API_HF_TOKEN')
17
+
18
+ REPO_ID = "Loren/articles_database"
19
+ FAISS_REPO_ID = "Loren/articles_faiss"
20
+ FAISS_INDEX_FILE = "faiss_index.bin"
21
+ MODEL_NAME = "intfloat/multilingual-e5-small"
22
+ CROSS_ENCODER_NAME = "jinaai/jina-reranker-v2-base-multilingual"
23
+
24
+ cache_dir = "/tmp"
25
+ os.makedirs(cache_dir, exist_ok=True)
26
+ # Rediriger le cache HF globalement
27
+ os.environ["HF_HOME"] = cache_dir
28
+ os.environ["HF_DATASETS_CACHE"] = cache_dir
29
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
30
+
31
+ # Téléchargement des fichiers Parquet depuis Hugging Face
32
+ articles_parquet = hf_hub_download(
33
+ repo_id=REPO_ID,
34
+ filename="articles.parquet",
35
+ repo_type="dataset",
36
+ cache_dir=cache_dir)
37
+ tags_parquet = hf_hub_download(
38
+ repo_id=REPO_ID,
39
+ filename="tags.parquet",
40
+ repo_type="dataset",
41
+ cache_dir=cache_dir)
42
+ tag_article_parquet = hf_hub_download(
43
+ repo_id=REPO_ID,
44
+ filename="tag_article.parquet",
45
+ repo_type="dataset",
46
+ cache_dir=cache_dir)
47
+
48
+ # Connexion DuckDB en mémoire
49
+ con = duckdb.connect()
50
+
51
+ # Créer des tables DuckDB directement à partir des fichiers Parquet
52
+ con.execute(f"CREATE VIEW articles AS SELECT * FROM parquet_scan('{articles_parquet}')")
53
+ con.execute(f"CREATE VIEW tags AS SELECT * FROM parquet_scan('{tags_parquet}')")
54
+ con.execute(f"CREATE VIEW tag_article AS SELECT * FROM parquet_scan('{tag_article_parquet}')")
55
+
56
+ # Téléchargement des fichiers de la base faiss depuis le dataset Hugging Face
57
+ hf_faiss_index = hf_hub_download(
58
+ repo_id=FAISS_REPO_ID,
59
+ filename=FAISS_INDEX_FILE,
60
+ repo_type="dataset",
61
+ token=HF_TOKEN,
62
+ cache_dir=cache_dir
63
+ )
64
+
65
+ # Chargement de l’index FAISS
66
+ faiss_index = faiss.read_index(hf_faiss_index)
67
+
68
+ # Téléchargement des metadatas Faiss depuis le dataset Hugging Face
69
+ dataset = load_dataset(FAISS_REPO_ID, split="train")
70
+ arrow_table = dataset.data
71
+
72
+ # Creation du Sentence transformer model
73
+ device = "cuda" if torch.cuda.is_available() else "cpu"
74
+ print(f"*** Device: {device}")
75
+ model = SentenceTransformer(MODEL_NAME, device=device)
76
+
77
+ # Création du cross-encoder
78
+ cross_encoder = CrossEncoder(CROSS_ENCODER_NAME, device=device,
79
+ trust_remote_code=True)
80
+
81
+
82
+ # Fonctions d'accès aux données
83
+
84
+ def fetch_tags() -> List[str]:
85
+ """
86
+ Récupère la liste de tous les tags disponibles dans la base de données.
87
+
88
+ Returns:
89
+ Dict: Un dictionnaire contenant le statut et les résultats.
90
+ - Si succès :
91
+ {
92
+ "status": "ok",
93
+ "result": List[str] # Liste des noms de tags triés par ordre alphabétique
94
+ }
95
+ - En cas d'erreur :
96
+ {
97
+ "status": "error",
98
+ "code": str, # Nom de l'exception
99
+ "message": str # Message de l'exception
100
+ }
101
+ """
102
+ try:
103
+ query = "SELECT tag_name FROM tags ORDER BY tag_name"
104
+ result = con.execute(query).fetchall()
105
+ return {"status": "ok", "result": [row[0] for row in result]}
106
+ except Exception as e:
107
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
108
+
109
+ def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
110
+ """
111
+ Récupère les articles associés à un ou plusieurs tags.
112
+
113
+ Args:
114
+ tags (List[str]): Une liste de noms de tags pour filtrer les articles.
115
+
116
+ Returns:
117
+ Dict: Un dictionnaire contenant le statut et les résultats.
118
+ - Si succès :
119
+ {
120
+ "status": "ok",
121
+ "result": List[Dict] # Liste de dictionnaires représentant les articles
122
+ }
123
+ Chaque dictionnaire contient les clés :
124
+ - 'article_id': int, ID de l'article
125
+ - 'article_title': str, Titre de l'article
126
+ - 'article_url': str, URL de l'article
127
+ - En cas d'erreur ou si aucun tag fourni :
128
+ {
129
+ "status": "error",
130
+ "code": str, # Code d'erreur ou nom de l'exception
131
+ "message": str # Message d'erreur
132
+ }
133
+
134
+ Notes:
135
+ - Si la liste `tags` est vide, la fonction retourne une liste vide.
136
+ - Les résultats incluent uniquement les articles correspondant à au moins un des tags fournis.
137
+ """
138
+ if not tags:
139
+ return {"status": "error", "code": "no_tags", "message": "Aucun tag fourni."}
140
+
141
+ try:
142
+ placeholders = ",".join(["?"] * len(tags))
143
+ query = f"""SELECT distinct a.article_id, a.article_title, a.article_url
144
+ FROM tags t, tag_article ta, articles a
145
+ WHERE t.tag_id = ta.tag_id
146
+ AND ta.article_id = a.article_id
147
+ AND t.tag_name IN ({placeholders})
148
+ """
149
+ result = con.execute(query, tags).fetchdf()
150
+ return {"status": "ok", "result": result.to_dict(orient="records")}
151
+ except Exception as e:
152
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
153
+
154
+ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
155
+ """
156
+ Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
157
+ et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
158
+
159
+ Paramètres
160
+ ----------
161
+ query : str
162
+ La requête texte fournie par l'utilisateur.
163
+ k_model : int, optionnel (défaut = 10)
164
+ Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
165
+ k_cross : int, optionnel (défaut = 5)
166
+ Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
167
+
168
+ Retour
169
+ ------
170
+ Dict[str, Any]
171
+ Un dictionnaire contenant :
172
+ - status : "ok" si succès, sinon "error"
173
+ - result : liste de résultats (si succès)
174
+ - code et message : informations d'erreur (si échec)
175
+ """
176
+ if not query:
177
+ return {"status": "error", "code": "no_query", "message": "Aucun query fourni."}
178
+ try:
179
+ query_vec = model.encode(["query: "+query], convert_to_numpy=True, normalize_embeddings=True)
180
+ distances, indices = faiss_index.search(query_vec, k_model)
181
+
182
+ # Résultats FAISS
183
+ faiss_ids_list = indices[0].tolist()
184
+ distances_list = distances[0].tolist()
185
+
186
+ # Filtrer Arrow sur les IDs trouvés
187
+ filtered_table = arrow_table.filter(
188
+ pc.is_in(arrow_table['faiss_id'],
189
+ value_set=pa.array(faiss_ids_list))
190
+ )
191
+
192
+ # Convertir Arrow → pandas pour ajouter la distance
193
+ df = filtered_table.to_pandas()
194
+
195
+ # Ajouter la distance en gardant l'ordre faiss_ids_list
196
+ distance_map = dict(zip(faiss_ids_list, distances_list))
197
+ df["distance"] = df["faiss_id"].map(distance_map)
198
+
199
+ # Cross-encoder
200
+ top_passages = df["chunk_text"].tolist()
201
+ cross_input = [(query, p) for p in top_passages]
202
+ cross_scores = cross_encoder.predict(cross_input)
203
+
204
+ # Rerank
205
+ df["cross_score"] = cross_scores
206
+ df = df.sort_values(by="cross_score", ascending=False)
207
+
208
+ # Garder top k_cross
209
+ df_top = df.head(k_cross)
210
+
211
+ # Enregistrer dans DuckDB
212
+ con.register("faiss_tmp", df_top)
213
+
214
+ sql = """
215
+ SELECT
216
+ f.faiss_id,
217
+ f.document_id,
218
+ f.distance,
219
+ f.cross_score,
220
+ f.chunk_text,
221
+ a.article_title,
222
+ CASE WHEN a.article_online
223
+ THEN a.article_url
224
+ ELSE 'Article unavailable' END AS url,
225
+ STRING_AGG(t.tag_name, ', ') AS tags
226
+ FROM faiss_tmp f
227
+ JOIN articles a ON f.document_id = a.article_id
228
+ JOIN tag_article ta ON a.article_id = ta.article_id
229
+ JOIN tags t ON ta.tag_id = t.tag_id
230
+ WHERE (LENGTH(article_text) - LENGTH(REPLACE(article_text, ' ', '')) + 1) >= 100
231
+ GROUP BY f.faiss_id, f.document_id, f.distance, f.cross_score, f.chunk_text,
232
+ a.article_title, a.article_online, a.article_url
233
+ ORDER BY AVG(f.cross_score)
234
+ """
235
+
236
+ duck_res = con.execute(sql).fetchdf()
237
+
238
+ # Liste finale de dictionnaires
239
+ list_result = duck_res.to_dict(orient="records")
240
+
241
+ return {"status": "ok", "result": list_result}
242
+ except Exception as e:
243
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}