Loren commited on
Commit
1751d36
·
verified ·
1 Parent(s): 5715b2d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app/database.py +32 -16
  2. app/main.py +225 -188
app/database.py CHANGED
@@ -21,6 +21,9 @@ from dotenv import load_dotenv
21
  import pyarrow as pa
22
  import pyarrow.compute as pc
23
 
 
 
 
24
  # Initialisations
25
  load_dotenv()
26
  HF_TOKEN = os.getenv('API_HF_TOKEN')
@@ -161,10 +164,13 @@ def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
161
  except Exception as e:
162
  return {"status": "error", "code": type(e).__name__, "message": str(e)}
163
 
164
- def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict[str, Any]:
 
 
165
  """
166
- Exécute une requête de recherche sémantique avec FAISS, puis rerank avec un cross-encoder
167
- et retourne les meilleurs passages enrichis avec des métadonnées provenant de DuckDB.
 
168
 
169
  Paramètres
170
  ----------
@@ -174,6 +180,8 @@ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict
174
  Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
175
  k_cross : int, optionnel (défaut = 5)
176
  Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
 
 
177
 
178
  Retour
179
  ------
@@ -206,18 +214,26 @@ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict
206
  distance_map = dict(zip(faiss_ids_list, distances_list))
207
  df["distance"] = df["faiss_id"].map(distance_map)
208
 
209
- # Cross-encoder
210
- df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip()
211
- top_passages = df["chunk_text"].tolist()
212
- cross_input = [(query, p) for p in top_passages]
213
- cross_scores = cross_encoder.predict(cross_input)
214
-
215
- # Rerank
216
- df["cross_score"] = cross_scores
217
- df = df.sort_values(by="cross_score", ascending=False)
218
-
219
- # Garder top k_cross
220
- df_top = df.head(k_cross)
 
 
 
 
 
 
 
 
221
 
222
  # Enregistrer dans DuckDB
223
  con.register("faiss_tmp", df_top)
@@ -250,7 +266,7 @@ def fetch_query_results(query: str, k_model: int = 10, k_cross: int = 5) -> Dict
250
  # Liste finale de dictionnaires
251
  list_result = duck_res.to_dict(orient="records")
252
 
253
- return {"status": "ok", "result": list_result}
254
  except Exception as e:
255
  return {"status": "error", "code": type(e).__name__, "message": str(e)}
256
 
 
21
  import pyarrow as pa
22
  import pyarrow.compute as pc
23
 
24
+ import logging
25
+ logging.basicConfig(level=logging.DEBUG)
26
+
27
  # Initialisations
28
  load_dotenv()
29
  HF_TOKEN = os.getenv('API_HF_TOKEN')
 
164
  except Exception as e:
165
  return {"status": "error", "code": type(e).__name__, "message": str(e)}
166
 
167
+ def fetch_query_results(query: str, k_model: int = 10,
168
+ k_cross: int = 5, use_rerank: bool = True
169
+ ) -> Dict[str, Any]:
170
  """
171
+ Exécute une requête de recherche sémantique avec FAISS, puis (optionnellement)
172
+ rerank avec un cross-encoder et retourne les meilleurs passages enrichis avec
173
+ des métadonnées provenant de DuckDB.
174
 
175
  Paramètres
176
  ----------
 
180
  Nombre de résultats les plus proches à récupérer depuis l'index FAISS.
181
  k_cross : int, optionnel (défaut = 5)
182
  Nombre de résultats finaux à conserver après reranking avec le cross-encoder.
183
+ use_rerank : bool, optionnel (défaut = True)
184
+ Si False, on désactive complètement le cross-encoder et le rerank.
185
 
186
  Retour
187
  ------
 
214
  distance_map = dict(zip(faiss_ids_list, distances_list))
215
  df["distance"] = df["faiss_id"].map(distance_map)
216
 
217
+ if use_rerank:
218
+ status_dbg = "ok_rerank"
219
+ # Cross-encoder
220
+ df["chunk_text"] = df["chunk_text"].str.replace(r'\s+', ' ', regex=True).str.strip()
221
+ top_passages = df["chunk_text"].tolist()
222
+ cross_input = [(query, p) for p in top_passages]
223
+ cross_scores = cross_encoder.predict(cross_input)
224
+
225
+ # Rerank
226
+ df["cross_score"] = cross_scores
227
+ df = df.sort_values(by="cross_score", ascending=False)
228
+
229
+ # Garder top k_cross
230
+ df_top = df.head(k_cross)
231
+ else:
232
+ status_dbg = "ok_no_rerank"
233
+ df = df.sort_values(by="distance", ascending=False)
234
+ df["cross_score"] = df["distance"]
235
+ # Garder top k_model
236
+ df_top = df.head(k_model)
237
 
238
  # Enregistrer dans DuckDB
239
  con.register("faiss_tmp", df_top)
 
266
  # Liste finale de dictionnaires
267
  list_result = duck_res.to_dict(orient="records")
268
 
269
+ return {"status": status_dbg, "result": list_result}
270
  except Exception as e:
271
  return {"status": "error", "code": type(e).__name__, "message": str(e)}
272
 
app/main.py CHANGED
@@ -1,189 +1,226 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import HTMLResponse
6
-
7
- from pydantic import BaseModel
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
- import torch
10
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
-
12
- app = FastAPI(
13
- title="Articles API",
14
- description="API pour récupérer articles et tags depuis SQLite",
15
- version="1.0"
16
- )
17
-
18
- # Chargement du modèle génératif
19
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
- torch_dtype=torch.float16,
23
- device_map="auto"
24
- )
25
-
26
- # CORS pour permettre l'accès depuis le navigateur
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"], # autorise toutes les origines
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- @app.get("/", response_class=HTMLResponse)
36
- def home():
37
- return """
38
- <html>
39
- <head><title>Page d'accueil</title></head>
40
- <body>
41
- <h1>Welcome on the API search articles !</h1>
42
- </body>
43
- </html>
44
- """
45
-
46
- @app.get("/get_tags")
47
- def get_tags():
48
- """
49
- Récupère la liste de tous les tags disponibles via l'API.
50
-
51
- Returns:
52
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
- - Si succès :
54
- {
55
- "status": "ok",
56
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
- }
58
- - En cas d'erreur :
59
- {
60
- "status": "error",
61
- "code": str, # Nom de l'exception
62
- "message": str # Message de l'exception
63
- }
64
-
65
- Notes:
66
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
- """
69
- try:
70
- dict_result = database.fetch_tags()
71
- if dict_result["status"] == "ok":
72
- return {"status": "ok", "tags": dict_result["result"]}
73
- else:
74
- return dict_result
75
- except Exception as e:
76
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
-
78
- @app.get("/get_articles_with_tags")
79
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
- """
81
- Récupère les articles associés à une ou plusieurs tags spécifiés.
82
-
83
- Args:
84
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
- Doit contenir au moins un tag.
86
-
87
- Returns:
88
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
- - Si succès :
90
- {
91
- "status": "ok",
92
- "tags": List[str], # Tags utilisés pour filtrer
93
- "articles": List[Dict] # Liste des articles correspondants
94
- }
95
- Chaque article est un dictionnaire contenant :
96
- - 'article_id': int, ID de l'article
97
- - 'article_title': str, Titre de l'article
98
- - 'article_url': str, URL de l'article
99
- - En cas d'erreur :
100
- {
101
- "status": "error",
102
- "code": str, # Code d'erreur ou nom de l'exception
103
- "message": str # Message d'erreur
104
- }
105
-
106
- Notes:
107
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
- """
110
- try:
111
- dict_result = database.fetch_articles_by_tags(tags)
112
- if dict_result["status"] == "ok":
113
- return {"status": "ok",
114
- "tags": tags,
115
- "articles": dict_result["result"]}
116
- else:
117
- return dict_result
118
- except Exception as e:
119
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
-
121
-
122
- @app.get("/get_query_results")
123
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
126
- ) -> Dict[str, Any]:
127
- """
128
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
129
-
130
- Args:
131
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
132
- k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
133
- k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
134
-
135
- Returns:
136
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
137
-
138
- Notes:
139
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
140
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
141
- """
142
- try:
143
- dict_result = database.fetch_query_results(query, k_model, k_cross)
144
- if dict_result["status"] == "ok":
145
- return {"status": "ok",
146
- "results": dict_result["result"]}
147
- else:
148
- return dict_result
149
- except Exception as e:
150
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
151
-
152
- # 🔹 Exemple de modèle d'entrée utilisateur
153
- class QueryRequest(BaseModel):
154
- question: str
155
-
156
- @app.post("/ask")
157
- async def ask_question(request: QueryRequest):
158
- try:
159
- user_query = request.question.strip()
160
- dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
161
- if dict_result["status"] == "ok":
162
- list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
163
- if not list_chunks:
164
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
165
- "Je peux uniquement répondre à des questions sur les articles " \
166
- "du jeu de données.")
167
- else:
168
- # Construction du prompt
169
- prompt = RAG_PROMPT_TEMPLATE.format(
170
- context="\n".join(list_chunks),
171
- question=user_query
172
- )
173
- print("*** Prompt : ", prompt)
174
- # Génération de la réponse
175
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
176
- outputs = model.generate(**inputs, max_new_tokens=500)
177
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
178
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
179
- return {"status": "ok",
180
- "results": dict_result["result"],
181
- "answer": answer}
182
- else:
183
- answer = f"Une erreur est survenue lors de la récupération des informations : \
184
- {dict_result['code']} - {dict_result['message']}."
185
- return {"status": "error", "answer": answer}
186
- except Exception as e:
187
- answer = f"Une erreur est survenue lors de la récupération des informations : \
188
- {type(e).__name__} - {str(e)}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  return {"status": "error", "answer": answer}
 
1
+ from fastapi import FastAPI, Query
2
+ from typing import List, Dict, Any
3
+ from app import database
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import HTMLResponse
6
+
7
+ from pydantic import BaseModel
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import torch
10
+ from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
+
12
+ app = FastAPI(
13
+ title="Articles API",
14
+ description="API pour récupérer articles et tags depuis SQLite",
15
+ version="1.0"
16
+ )
17
+
18
+ # Chargement du modèle génératif
19
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
+ torch_dtype=torch.float16,
23
+ device_map="auto"
24
+ )
25
+
26
+ # CORS pour permettre l'accès depuis le navigateur
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # autorise toutes les origines
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ @app.get("/", response_class=HTMLResponse)
36
+ def home():
37
+ return """
38
+ <html>
39
+ <head><title>Page d'accueil</title></head>
40
+ <body>
41
+ <h1>Welcome on the API search articles !</h1>
42
+ </body>
43
+ </html>
44
+ """
45
+
46
+ @app.get("/get_tags")
47
+ def get_tags():
48
+ """
49
+ Récupère la liste de tous les tags disponibles via l'API.
50
+
51
+ Returns:
52
+ Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
+ - Si succès :
54
+ {
55
+ "status": "ok",
56
+ "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
+ }
58
+ - En cas d'erreur :
59
+ {
60
+ "status": "error",
61
+ "code": str, # Nom de l'exception
62
+ "message": str # Message de l'exception
63
+ }
64
+
65
+ Notes:
66
+ - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
+ - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
+ """
69
+ try:
70
+ dict_result = database.fetch_tags()
71
+ if dict_result["status"] == "ok":
72
+ return {"status": "ok", "tags": dict_result["result"]}
73
+ else:
74
+ return dict_result
75
+ except Exception as e:
76
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
+
78
+ @app.get("/get_articles_with_tags")
79
+ def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
+ """
81
+ Récupère les articles associés à une ou plusieurs tags spécifiés.
82
+
83
+ Args:
84
+ tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
+ Doit contenir au moins un tag.
86
+
87
+ Returns:
88
+ Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
+ - Si succès :
90
+ {
91
+ "status": "ok",
92
+ "tags": List[str], # Tags utilisés pour filtrer
93
+ "articles": List[Dict] # Liste des articles correspondants
94
+ }
95
+ Chaque article est un dictionnaire contenant :
96
+ - 'article_id': int, ID de l'article
97
+ - 'article_title': str, Titre de l'article
98
+ - 'article_url': str, URL de l'article
99
+ - En cas d'erreur :
100
+ {
101
+ "status": "error",
102
+ "code": str, # Code d'erreur ou nom de l'exception
103
+ "message": str # Message d'erreur
104
+ }
105
+
106
+ Notes:
107
+ - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
+ - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
+ """
110
+ try:
111
+ dict_result = database.fetch_articles_by_tags(tags)
112
+ if dict_result["status"] == "ok":
113
+ return {"status": "ok",
114
+ "tags": tags,
115
+ "articles": dict_result["result"]}
116
+ else:
117
+ return dict_result
118
+ except Exception as e:
119
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
+
121
+
122
+ @app.get("/get_query_results")
123
+ def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
+ k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
+ k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"),
126
+ use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé")
127
+ ) -> Dict[str, Any]:
128
+ """
129
+ Récupère les résultats d'une requête en utilisant deux modèles de recherche.
130
+
131
+ Args:
132
+ query (str): La requête utilisateur pour laquelle récupérer les résultats.
133
+ k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
134
+ k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
135
+ use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True.
136
+ Si False, on désactive complètement le cross-encoder et le rerank.
137
+ Returns:
138
+ Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
139
+
140
+ Notes:
141
+ - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
142
+ - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
143
+ """
144
+ try:
145
+ dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank)
146
+ if dict_result["status"] == "ok":
147
+ return {"status": "ok",
148
+ "results": dict_result["result"]}
149
+ else:
150
+ return dict_result
151
+ except Exception as e:
152
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
153
+
154
+ #
155
+ class QueryRequest(BaseModel):
156
+ question: str
157
+ use_rerank: bool = True
158
+
159
+ @app.post("/get_answer")
160
+ async def ask_question(request: QueryRequest):
161
+ """
162
+ Traite une question utilisateur en effectuant une recherche dans la base de données
163
+ puis en générant une réponse à l’aide du modèle de langage (RAG).
164
+
165
+ Le fonctionnement se déroule en trois étapes :
166
+ 1. Extraction et nettoyage de la question utilisateur.
167
+ 2. Recherche des passages pertinents dans la base de données (`fetch_query_results`).
168
+ 3. Génération d’une réponse fondée sur les morceaux de texte récupérés (RAG).
169
+
170
+ Paramètres
171
+ ----------
172
+ request : QueryRequest
173
+ Objet contenant la question utilisateur sous forme de chaîne de caractères.
174
+
175
+ Retour
176
+ ------
177
+ dict
178
+ Un dictionnaire contenant :
179
+ - "status" : "ok" si la requête a réussi, "error" en cas d'échec.
180
+ - "results" : liste des chunks retournés par la base de données (présent seulement si status = "ok").
181
+ - "answer" : réponse générée par le modèle ou message d’erreur.
182
+
183
+ Exceptions
184
+ ----------
185
+ Toute exception survenant durant l’exécution est interceptée
186
+ et retournée sous forme d’un message d’erreur dans la clé "answer".
187
+
188
+ Notes
189
+ -----
190
+ - Si aucun chunk pertinent n'est trouvé, la fonction renvoie un message indiquant
191
+ que seules les questions relatives aux articles du jeu de données peuvent être traitées.
192
+ - La génération de la réponse utilise un template RAG et produit jusqu’à 500 tokens.
193
+ """
194
+ try:
195
+ user_query = request.question.strip()
196
+ use_rerank = request.use_rerank
197
+ dict_result = database.fetch_query_results(user_query, k_model=10,
198
+ k_cross=5, use_rerank=use_rerank)
199
+ if dict_result["status"] == "ok":
200
+ list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
201
+ if not list_chunks:
202
+ answer = ("Je ne dispose pas d’informations sur ce sujet. "
203
+ "Je peux uniquement répondre à des questions sur les articles " \
204
+ "du jeu de données.")
205
+ else:
206
+ # Construction du prompt
207
+ prompt = RAG_PROMPT_TEMPLATE.format(
208
+ context="\n".join(list_chunks),
209
+ question=user_query
210
+ )
211
+ # Génération de la réponse
212
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
213
+ outputs = model.generate(**inputs, max_new_tokens=500)
214
+ generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
215
+ answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
216
+ return {"status": "ok",
217
+ "results": dict_result["result"],
218
+ "answer": answer}
219
+ else:
220
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
221
+ {dict_result['code']} - {dict_result['message']}."
222
+ return {"status": "error", "answer": answer}
223
+ except Exception as e:
224
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
225
+ {type(e).__name__} - {str(e)}."
226
  return {"status": "error", "answer": answer}