Loren commited on
Commit
fe53425
·
verified ·
1 Parent(s): 1c255df

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +226 -245
app/main.py CHANGED
@@ -1,246 +1,227 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import HTMLResponse
6
-
7
- from pydantic import BaseModel
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
9
- import torch
10
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
-
12
-
13
- app = FastAPI(
14
- title="Articles API",
15
- description="API pour récupérer articles et tags depuis SQLite",
16
- version="1.0"
17
- )
18
-
19
- # Chargement du modèle génératif
20
- #MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
21
- #tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
- #model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
23
- # torch_dtype=torch.float16,
24
- # device_map="auto"
25
- # )
26
- model_id = "mistralai/Mistral-7B-Instruct-v0.2"
27
- # Charger le tokenizer
28
- tokenizer = AutoTokenizer.from_pretrained(model_id)
29
-
30
- # Config de quantization moderne (4-bit ou 8-bit)
31
- quant_config = BitsAndBytesConfig(
32
- load_in_4bit=True, # False pour int8
33
- bnb_4bit_compute_dtype=torch.float16, # dtype des calculs
34
- bnb_4bit_use_double_quant=True,
35
- bnb_4bit_quant_type="nf4"
36
- )
37
-
38
- # Charger le modèle avec la nouvelle API
39
- model = AutoModelForCausalLM.from_pretrained(
40
- model_id,
41
- quantization_config=quant_config,
42
- device_map="auto", # pour GPU auto
43
- dtype=torch.float16
44
- )
45
-
46
- # CORS pour permettre l'accès depuis le navigateur
47
- app.add_middleware(
48
- CORSMiddleware,
49
- allow_origins=["*"], # autorise toutes les origines
50
- allow_credentials=True,
51
- allow_methods=["*"],
52
- allow_headers=["*"],
53
- )
54
-
55
- @app.get("/", response_class=HTMLResponse)
56
- def home():
57
- return """
58
- <html>
59
- <head><title>Page d'accueil</title></head>
60
- <body>
61
- <h1>Welcome on the API search articles !</h1>
62
- </body>
63
- </html>
64
- """
65
-
66
- @app.get("/get_tags")
67
- def get_tags():
68
- """
69
- Récupère la liste de tous les tags disponibles via l'API.
70
-
71
- Returns:
72
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
73
- - Si succès :
74
- {
75
- "status": "ok",
76
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
77
- }
78
- - En cas d'erreur :
79
- {
80
- "status": "error",
81
- "code": str, # Nom de l'exception
82
- "message": str # Message de l'exception
83
- }
84
-
85
- Notes:
86
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
87
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
88
- """
89
- try:
90
- dict_result = database.fetch_tags()
91
- if dict_result["status"] == "ok":
92
- return {"status": "ok", "tags": dict_result["result"]}
93
- else:
94
- return dict_result
95
- except Exception as e:
96
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
97
-
98
- @app.get("/get_articles_with_tags")
99
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
100
- """
101
- Récupère les articles associés à une ou plusieurs tags spécifiés.
102
-
103
- Args:
104
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
105
- Doit contenir au moins un tag.
106
-
107
- Returns:
108
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
109
- - Si succès :
110
- {
111
- "status": "ok",
112
- "tags": List[str], # Tags utilisés pour filtrer
113
- "articles": List[Dict] # Liste des articles correspondants
114
- }
115
- Chaque article est un dictionnaire contenant :
116
- - 'article_id': int, ID de l'article
117
- - 'article_title': str, Titre de l'article
118
- - 'article_url': str, URL de l'article
119
- - En cas d'erreur :
120
- {
121
- "status": "error",
122
- "code": str, # Code d'erreur ou nom de l'exception
123
- "message": str # Message d'erreur
124
- }
125
-
126
- Notes:
127
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
128
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
129
- """
130
- try:
131
- dict_result = database.fetch_articles_by_tags(tags)
132
- if dict_result["status"] == "ok":
133
- return {"status": "ok",
134
- "tags": tags,
135
- "articles": dict_result["result"]}
136
- else:
137
- return dict_result
138
- except Exception as e:
139
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
140
-
141
-
142
- @app.get("/get_query_results")
143
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
144
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
145
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"),
146
- use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé")
147
- ) -> Dict[str, Any]:
148
- """
149
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
150
-
151
- Args:
152
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
153
- k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
154
- k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
155
- use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True.
156
- Si False, on désactive complètement le cross-encoder et le rerank.
157
- Returns:
158
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
159
-
160
- Notes:
161
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
162
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
163
- """
164
- try:
165
- dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank)
166
- if dict_result["status"] == "ok":
167
- return {"status": "ok",
168
- "results": dict_result["result"]}
169
- else:
170
- return dict_result
171
- except Exception as e:
172
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
173
-
174
- #
175
- class QueryRequest(BaseModel):
176
- question: str
177
- use_rerank: bool = True
178
-
179
- @app.post("/get_answer")
180
- async def ask_question(request: QueryRequest):
181
- """
182
- Traite une question utilisateur en effectuant une recherche dans la base de données
183
- puis en générant une réponse à l’aide du modèle de langage (RAG).
184
-
185
- Le fonctionnement se déroule en trois étapes :
186
- 1. Extraction et nettoyage de la question utilisateur.
187
- 2. Recherche des passages pertinents dans la base de données (`fetch_query_results`).
188
- 3. Génération d’une réponse fondée sur les morceaux de texte récupérés (RAG).
189
-
190
- Paramètres
191
- ----------
192
- request : QueryRequest
193
- Objet contenant la question utilisateur sous forme de chaîne de caractères.
194
-
195
- Retour
196
- ------
197
- dict
198
- Un dictionnaire contenant :
199
- - "status" : "ok" si la requête a réussi, "error" en cas d'échec.
200
- - "results" : liste des chunks retournés par la base de données (présent seulement si status = "ok").
201
- - "answer" : réponse générée par le modèle ou message d’erreur.
202
-
203
- Exceptions
204
- ----------
205
- Toute exception survenant durant l’exécution est interceptée
206
- et retournée sous forme d’un message d’erreur dans la clé "answer".
207
-
208
- Notes
209
- -----
210
- - Si aucun chunk pertinent n'est trouvé, la fonction renvoie un message indiquant
211
- que seules les questions relatives aux articles du jeu de données peuvent être traitées.
212
- - La génération de la réponse utilise un template RAG et produit jusqu’à 500 tokens.
213
- """
214
- try:
215
- user_query = request.question.strip()
216
- use_rerank = request.use_rerank
217
- dict_result = database.fetch_query_results(user_query, k_model=10,
218
- k_cross=5, use_rerank=use_rerank)
219
- if dict_result["status"] == "ok":
220
- list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
221
- if not list_chunks:
222
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
223
- "Je peux uniquement répondre à des questions sur les articles " \
224
- "du jeu de données.")
225
- else:
226
- # Construction du prompt
227
- prompt = RAG_PROMPT_TEMPLATE.format(
228
- context="\n".join(list_chunks),
229
- question=user_query
230
- )
231
- # Génération de la réponse
232
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
233
- outputs = model.generate(**inputs, max_new_tokens=500)
234
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
235
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
236
- return {"status": "ok",
237
- "results": dict_result["result"],
238
- "answer": answer}
239
- else:
240
- answer = f"Une erreur est survenue lors de la récupération des informations : \
241
- {dict_result['code']} - {dict_result['message']}."
242
- return {"status": "error", "answer": answer}
243
- except Exception as e:
244
- answer = f"Une erreur est survenue lors de la récupération des informations : \
245
- {type(e).__name__} - {str(e)}."
246
  return {"status": "error", "answer": answer}
 
1
+ from fastapi import FastAPI, Query
2
+ from typing import List, Dict, Any
3
+ from app import database
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import HTMLResponse
6
+
7
+ from pydantic import BaseModel
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ import torch
10
+ from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
+
12
+
13
+ app = FastAPI(
14
+ title="Articles API",
15
+ description="API pour récupérer articles et tags depuis SQLite",
16
+ version="1.0"
17
+ )
18
+
19
+ # Chargement du modèle génératif
20
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
23
+ torch_dtype=torch.float16,
24
+ device_map="auto"
25
+ )
26
+
27
+ # CORS pour permettre l'accès depuis le navigateur
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=["*"], # autorise toutes les origines
31
+ allow_credentials=True,
32
+ allow_methods=["*"],
33
+ allow_headers=["*"],
34
+ )
35
+
36
+ @app.get("/", response_class=HTMLResponse)
37
+ def home():
38
+ return """
39
+ <html>
40
+ <head><title>Page d'accueil</title></head>
41
+ <body>
42
+ <h1>Welcome on the API search articles !</h1>
43
+ </body>
44
+ </html>
45
+ """
46
+
47
+ @app.get("/get_tags")
48
+ def get_tags():
49
+ """
50
+ Récupère la liste de tous les tags disponibles via l'API.
51
+
52
+ Returns:
53
+ Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
54
+ - Si succès :
55
+ {
56
+ "status": "ok",
57
+ "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
58
+ }
59
+ - En cas d'erreur :
60
+ {
61
+ "status": "error",
62
+ "code": str, # Nom de l'exception
63
+ "message": str # Message de l'exception
64
+ }
65
+
66
+ Notes:
67
+ - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
68
+ - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
69
+ """
70
+ try:
71
+ dict_result = database.fetch_tags()
72
+ if dict_result["status"] == "ok":
73
+ return {"status": "ok", "tags": dict_result["result"]}
74
+ else:
75
+ return dict_result
76
+ except Exception as e:
77
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
78
+
79
+ @app.get("/get_articles_with_tags")
80
+ def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
81
+ """
82
+ Récupère les articles associés à une ou plusieurs tags spécifiés.
83
+
84
+ Args:
85
+ tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
86
+ Doit contenir au moins un tag.
87
+
88
+ Returns:
89
+ Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
90
+ - Si succès :
91
+ {
92
+ "status": "ok",
93
+ "tags": List[str], # Tags utilisés pour filtrer
94
+ "articles": List[Dict] # Liste des articles correspondants
95
+ }
96
+ Chaque article est un dictionnaire contenant :
97
+ - 'article_id': int, ID de l'article
98
+ - 'article_title': str, Titre de l'article
99
+ - 'article_url': str, URL de l'article
100
+ - En cas d'erreur :
101
+ {
102
+ "status": "error",
103
+ "code": str, # Code d'erreur ou nom de l'exception
104
+ "message": str # Message d'erreur
105
+ }
106
+
107
+ Notes:
108
+ - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
109
+ - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
110
+ """
111
+ try:
112
+ dict_result = database.fetch_articles_by_tags(tags)
113
+ if dict_result["status"] == "ok":
114
+ return {"status": "ok",
115
+ "tags": tags,
116
+ "articles": dict_result["result"]}
117
+ else:
118
+ return dict_result
119
+ except Exception as e:
120
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
121
+
122
+
123
+ @app.get("/get_query_results")
124
+ def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
125
+ k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
126
+ k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"),
127
+ use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé")
128
+ ) -> Dict[str, Any]:
129
+ """
130
+ Récupère les résultats d'une requête en utilisant deux modèles de recherche.
131
+
132
+ Args:
133
+ query (str): La requête utilisateur pour laquelle récupérer les résultats.
134
+ k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
135
+ k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
136
+ use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True.
137
+ Si False, on désactive complètement le cross-encoder et le rerank.
138
+ Returns:
139
+ Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
140
+
141
+ Notes:
142
+ - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
143
+ - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
144
+ """
145
+ try:
146
+ dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank)
147
+ if dict_result["status"] == "ok":
148
+ return {"status": "ok",
149
+ "results": dict_result["result"]}
150
+ else:
151
+ return dict_result
152
+ except Exception as e:
153
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
154
+
155
+ #
156
+ class QueryRequest(BaseModel):
157
+ question: str
158
+ use_rerank: bool = True
159
+
160
+ @app.post("/get_answer")
161
+ async def ask_question(request: QueryRequest):
162
+ """
163
+ Traite une question utilisateur en effectuant une recherche dans la base de données
164
+ puis en générant une réponse à l’aide du modèle de langage (RAG).
165
+
166
+ Le fonctionnement se déroule en trois étapes :
167
+ 1. Extraction et nettoyage de la question utilisateur.
168
+ 2. Recherche des passages pertinents dans la base de données (`fetch_query_results`).
169
+ 3. Génération d’une réponse fondée sur les morceaux de texte récupérés (RAG).
170
+
171
+ Paramètres
172
+ ----------
173
+ request : QueryRequest
174
+ Objet contenant la question utilisateur sous forme de chaîne de caractères.
175
+
176
+ Retour
177
+ ------
178
+ dict
179
+ Un dictionnaire contenant :
180
+ - "status" : "ok" si la requête a réussi, "error" en cas d'échec.
181
+ - "results" : liste des chunks retournés par la base de données (présent seulement si status = "ok").
182
+ - "answer" : réponse générée par le modèle ou message d’erreur.
183
+
184
+ Exceptions
185
+ ----------
186
+ Toute exception survenant durant l’exécution est interceptée
187
+ et retournée sous forme d’un message d’erreur dans la clé "answer".
188
+
189
+ Notes
190
+ -----
191
+ - Si aucun chunk pertinent n'est trouvé, la fonction renvoie un message indiquant
192
+ que seules les questions relatives aux articles du jeu de données peuvent être traitées.
193
+ - La génération de la réponse utilise un template RAG et produit jusqu’à 500 tokens.
194
+ """
195
+ try:
196
+ user_query = request.question.strip()
197
+ use_rerank = request.use_rerank
198
+ dict_result = database.fetch_query_results(user_query, k_model=10,
199
+ k_cross=5, use_rerank=use_rerank)
200
+ if dict_result["status"] == "ok":
201
+ list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
202
+ if not list_chunks:
203
+ answer = ("Je ne dispose pas d’informations sur ce sujet. "
204
+ "Je peux uniquement répondre à des questions sur les articles " \
205
+ "du jeu de données.")
206
+ else:
207
+ # Construction du prompt
208
+ prompt = RAG_PROMPT_TEMPLATE.format(
209
+ context="\n".join(list_chunks),
210
+ question=user_query
211
+ )
212
+ # Génération de la réponse
213
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
214
+ outputs = model.generate(**inputs, max_new_tokens=500)
215
+ generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
216
+ answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
217
+ return {"status": "ok",
218
+ "results": dict_result["result"],
219
+ "answer": answer}
220
+ else:
221
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
222
+ {dict_result['code']} - {dict_result['message']}."
223
+ return {"status": "error", "answer": answer}
224
+ except Exception as e:
225
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
226
+ {type(e).__name__} - {str(e)}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  return {"status": "error", "answer": answer}