Loren commited on
Commit
d5c8f86
·
verified ·
1 Parent(s): d721f56

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -226
main.py DELETED
@@ -1,226 +0,0 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import HTMLResponse
6
-
7
- from pydantic import BaseModel
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
- import torch
10
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
-
12
- app = FastAPI(
13
- title="Articles API",
14
- description="API pour récupérer articles et tags depuis SQLite",
15
- version="1.0"
16
- )
17
-
18
- # Chargement du modèle génératif
19
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
- torch_dtype=torch.float16,
23
- device_map="auto"
24
- )
25
-
26
- # CORS pour permettre l'accès depuis le navigateur
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"], # autorise toutes les origines
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- @app.get("/", response_class=HTMLResponse)
36
- def home():
37
- return """
38
- <html>
39
- <head><title>Page d'accueil</title></head>
40
- <body>
41
- <h1>Welcome on the API search articles !</h1>
42
- </body>
43
- </html>
44
- """
45
-
46
- @app.get("/get_tags")
47
- def get_tags():
48
- """
49
- Récupère la liste de tous les tags disponibles via l'API.
50
-
51
- Returns:
52
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
- - Si succès :
54
- {
55
- "status": "ok",
56
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
- }
58
- - En cas d'erreur :
59
- {
60
- "status": "error",
61
- "code": str, # Nom de l'exception
62
- "message": str # Message de l'exception
63
- }
64
-
65
- Notes:
66
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
- """
69
- try:
70
- dict_result = database.fetch_tags()
71
- if dict_result["status"] == "ok":
72
- return {"status": "ok", "tags": dict_result["result"]}
73
- else:
74
- return dict_result
75
- except Exception as e:
76
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
-
78
- @app.get("/get_articles_with_tags")
79
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
- """
81
- Récupère les articles associés à une ou plusieurs tags spécifiés.
82
-
83
- Args:
84
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
- Doit contenir au moins un tag.
86
-
87
- Returns:
88
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
- - Si succès :
90
- {
91
- "status": "ok",
92
- "tags": List[str], # Tags utilisés pour filtrer
93
- "articles": List[Dict] # Liste des articles correspondants
94
- }
95
- Chaque article est un dictionnaire contenant :
96
- - 'article_id': int, ID de l'article
97
- - 'article_title': str, Titre de l'article
98
- - 'article_url': str, URL de l'article
99
- - En cas d'erreur :
100
- {
101
- "status": "error",
102
- "code": str, # Code d'erreur ou nom de l'exception
103
- "message": str # Message d'erreur
104
- }
105
-
106
- Notes:
107
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
- """
110
- try:
111
- dict_result = database.fetch_articles_by_tags(tags)
112
- if dict_result["status"] == "ok":
113
- return {"status": "ok",
114
- "tags": tags,
115
- "articles": dict_result["result"]}
116
- else:
117
- return dict_result
118
- except Exception as e:
119
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
-
121
-
122
- @app.get("/get_query_results")
123
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"),
126
- use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé")
127
- ) -> Dict[str, Any]:
128
- """
129
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
130
-
131
- Args:
132
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
133
- k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
134
- k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
135
- use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True.
136
- Si False, on désactive complètement le cross-encoder et le rerank.
137
- Returns:
138
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
139
-
140
- Notes:
141
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
142
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
143
- """
144
- try:
145
- dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank)
146
- if dict_result["status"] == "ok":
147
- return {"status": "ok",
148
- "results": dict_result["result"]}
149
- else:
150
- return dict_result
151
- except Exception as e:
152
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
153
-
154
- #
155
- class QueryRequest(BaseModel):
156
- question: str
157
- use_rerank: bool = True
158
-
159
- @app.post("/get_answer")
160
- async def ask_question(request: QueryRequest):
161
- """
162
- Traite une question utilisateur en effectuant une recherche dans la base de données
163
- puis en générant une réponse à l’aide du modèle de langage (RAG).
164
-
165
- Le fonctionnement se déroule en trois étapes :
166
- 1. Extraction et nettoyage de la question utilisateur.
167
- 2. Recherche des passages pertinents dans la base de données (`fetch_query_results`).
168
- 3. Génération d’une réponse fondée sur les morceaux de texte récupérés (RAG).
169
-
170
- Paramètres
171
- ----------
172
- request : QueryRequest
173
- Objet contenant la question utilisateur sous forme de chaîne de caractères.
174
-
175
- Retour
176
- ------
177
- dict
178
- Un dictionnaire contenant :
179
- - "status" : "ok" si la requête a réussi, "error" en cas d'échec.
180
- - "results" : liste des chunks retournés par la base de données (présent seulement si status = "ok").
181
- - "answer" : réponse générée par le modèle ou message d’erreur.
182
-
183
- Exceptions
184
- ----------
185
- Toute exception survenant durant l’exécution est interceptée
186
- et retournée sous forme d’un message d’erreur dans la clé "answer".
187
-
188
- Notes
189
- -----
190
- - Si aucun chunk pertinent n'est trouvé, la fonction renvoie un message indiquant
191
- que seules les questions relatives aux articles du jeu de données peuvent être traitées.
192
- - La génération de la réponse utilise un template RAG et produit jusqu’à 500 tokens.
193
- """
194
- try:
195
- user_query = request.question.strip()
196
- use_rerank = request.use_rerank
197
- dict_result = database.fetch_query_results(user_query, k_model=10,
198
- k_cross=5, use_rerank=use_rerank)
199
- if dict_result["status"] == "ok":
200
- list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
201
- if not list_chunks:
202
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
203
- "Je peux uniquement répondre à des questions sur les articles " \
204
- "du jeu de données.")
205
- else:
206
- # Construction du prompt
207
- prompt = RAG_PROMPT_TEMPLATE.format(
208
- context="\n".join(list_chunks),
209
- question=user_query
210
- )
211
- # Génération de la réponse
212
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
213
- outputs = model.generate(**inputs, max_new_tokens=500)
214
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
215
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
216
- return {"status": "ok",
217
- "results": dict_result["result"],
218
- "answer": answer}
219
- else:
220
- answer = f"Une erreur est survenue lors de la récupération des informations : \
221
- {dict_result['code']} - {dict_result['message']}."
222
- return {"status": "error", "answer": answer}
223
- except Exception as e:
224
- answer = f"Une erreur est survenue lors de la récupération des informations : \
225
- {type(e).__name__} - {str(e)}."
226
- return {"status": "error", "answer": answer}