Loren commited on
Commit
9a517f4
·
verified ·
1 Parent(s): c4e2e17

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +188 -187
app/main.py CHANGED
@@ -1,188 +1,189 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import HTMLResponse
6
-
7
- from pydantic import BaseModel
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
- import torch
10
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
-
12
- app = FastAPI(
13
- title="Articles API",
14
- description="API pour récupérer articles et tags depuis SQLite",
15
- version="1.0"
16
- )
17
-
18
- # Chargement du modèle génératif
19
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
- torch_dtype=torch.float16,
23
- device_map="auto"
24
- )
25
-
26
- # CORS pour permettre l'accès depuis le navigateur
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"], # autorise toutes les origines
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- @app.get("/", response_class=HTMLResponse)
36
- def home():
37
- return """
38
- <html>
39
- <head><title>Page d'accueil</title></head>
40
- <body>
41
- <h1>Welcome on the API search articles !</h1>
42
- </body>
43
- </html>
44
- """
45
-
46
- @app.get("/get_tags")
47
- def get_tags():
48
- """
49
- Récupère la liste de tous les tags disponibles via l'API.
50
-
51
- Returns:
52
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
- - Si succès :
54
- {
55
- "status": "ok",
56
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
- }
58
- - En cas d'erreur :
59
- {
60
- "status": "error",
61
- "code": str, # Nom de l'exception
62
- "message": str # Message de l'exception
63
- }
64
-
65
- Notes:
66
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
- """
69
- try:
70
- dict_result = database.fetch_tags()
71
- if dict_result["status"] == "ok":
72
- return {"status": "ok", "tags": dict_result["result"]}
73
- else:
74
- return dict_result
75
- except Exception as e:
76
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
-
78
- @app.get("/get_articles_with_tags")
79
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
- """
81
- Récupère les articles associés à une ou plusieurs tags spécifiés.
82
-
83
- Args:
84
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
- Doit contenir au moins un tag.
86
-
87
- Returns:
88
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
- - Si succès :
90
- {
91
- "status": "ok",
92
- "tags": List[str], # Tags utilisés pour filtrer
93
- "articles": List[Dict] # Liste des articles correspondants
94
- }
95
- Chaque article est un dictionnaire contenant :
96
- - 'article_id': int, ID de l'article
97
- - 'article_title': str, Titre de l'article
98
- - 'article_url': str, URL de l'article
99
- - En cas d'erreur :
100
- {
101
- "status": "error",
102
- "code": str, # Code d'erreur ou nom de l'exception
103
- "message": str # Message d'erreur
104
- }
105
-
106
- Notes:
107
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
- """
110
- try:
111
- dict_result = database.fetch_articles_by_tags(tags)
112
- if dict_result["status"] == "ok":
113
- return {"status": "ok",
114
- "tags": tags,
115
- "articles": dict_result["result"]}
116
- else:
117
- return dict_result
118
- except Exception as e:
119
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
-
121
-
122
- @app.get("/get_query_results")
123
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
126
- ) -> Dict[str, Any]:
127
- """
128
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
129
-
130
- Args:
131
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
132
- k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
133
- k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
134
-
135
- Returns:
136
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
137
-
138
- Notes:
139
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
140
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
141
- """
142
- try:
143
- dict_result = database.fetch_query_results(query, k_model, k_cross)
144
- if dict_result["status"] == "ok":
145
- return {"status": "ok",
146
- "results": dict_result["result"]}
147
- else:
148
- return dict_result
149
- except Exception as e:
150
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
151
-
152
- # 🔹 Exemple de modèle d'entrée utilisateur
153
- class QueryRequest(BaseModel):
154
- question: str
155
-
156
- @app.post("/ask")
157
- async def ask_question(request: QueryRequest):
158
- try:
159
- user_query = request.question.strip()
160
- dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
161
- if dict_result["status"] == "ok":
162
- list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
163
- if not list_chunks:
164
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
165
- "Je peux uniquement répondre à des questions sur les articles " \
166
- "du jeu de données.")
167
- else:
168
- # Construction du prompt
169
- prompt = RAG_PROMPT_TEMPLATE.format(
170
- context="\n".join(list_chunks),
171
- question=user_query
172
- )
173
- # Génération de la réponse
174
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
175
- outputs = model.generate(**inputs, max_new_tokens=500)
176
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
177
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
178
- return {"status": "ok",
179
- "results": dict_result["result"],
180
- "answer": answer}
181
- else:
182
- answer = f"Une erreur est survenue lors de la récupération des informations : \
183
- {dict_result['code']} - {dict_result['message']}."
184
- return {"status": "error", "answer": answer}
185
- except Exception as e:
186
- answer = f"Une erreur est survenue lors de la récupération des informations : \
187
- {type(e).__name__} - {str(e)}."
 
188
  return {"status": "error", "answer": answer}
 
1
+ from fastapi import FastAPI, Query
2
+ from typing import List, Dict, Any
3
+ from app import database
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import HTMLResponse
6
+
7
+ from pydantic import BaseModel
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import torch
10
+ from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
+
12
+ app = FastAPI(
13
+ title="Articles API",
14
+ description="API pour récupérer articles et tags depuis SQLite",
15
+ version="1.0"
16
+ )
17
+
18
+ # Chargement du modèle génératif
19
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
+ torch_dtype=torch.float16,
23
+ device_map="auto"
24
+ )
25
+
26
+ # CORS pour permettre l'accès depuis le navigateur
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # autorise toutes les origines
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ @app.get("/", response_class=HTMLResponse)
36
+ def home():
37
+ return """
38
+ <html>
39
+ <head><title>Page d'accueil</title></head>
40
+ <body>
41
+ <h1>Welcome on the API search articles !</h1>
42
+ </body>
43
+ </html>
44
+ """
45
+
46
+ @app.get("/get_tags")
47
+ def get_tags():
48
+ """
49
+ Récupère la liste de tous les tags disponibles via l'API.
50
+
51
+ Returns:
52
+ Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
+ - Si succès :
54
+ {
55
+ "status": "ok",
56
+ "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
+ }
58
+ - En cas d'erreur :
59
+ {
60
+ "status": "error",
61
+ "code": str, # Nom de l'exception
62
+ "message": str # Message de l'exception
63
+ }
64
+
65
+ Notes:
66
+ - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
+ - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
+ """
69
+ try:
70
+ dict_result = database.fetch_tags()
71
+ if dict_result["status"] == "ok":
72
+ return {"status": "ok", "tags": dict_result["result"]}
73
+ else:
74
+ return dict_result
75
+ except Exception as e:
76
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
+
78
+ @app.get("/get_articles_with_tags")
79
+ def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
+ """
81
+ Récupère les articles associés à une ou plusieurs tags spécifiés.
82
+
83
+ Args:
84
+ tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
+ Doit contenir au moins un tag.
86
+
87
+ Returns:
88
+ Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
+ - Si succès :
90
+ {
91
+ "status": "ok",
92
+ "tags": List[str], # Tags utilisés pour filtrer
93
+ "articles": List[Dict] # Liste des articles correspondants
94
+ }
95
+ Chaque article est un dictionnaire contenant :
96
+ - 'article_id': int, ID de l'article
97
+ - 'article_title': str, Titre de l'article
98
+ - 'article_url': str, URL de l'article
99
+ - En cas d'erreur :
100
+ {
101
+ "status": "error",
102
+ "code": str, # Code d'erreur ou nom de l'exception
103
+ "message": str # Message d'erreur
104
+ }
105
+
106
+ Notes:
107
+ - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
+ - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
+ """
110
+ try:
111
+ dict_result = database.fetch_articles_by_tags(tags)
112
+ if dict_result["status"] == "ok":
113
+ return {"status": "ok",
114
+ "tags": tags,
115
+ "articles": dict_result["result"]}
116
+ else:
117
+ return dict_result
118
+ except Exception as e:
119
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
+
121
+
122
+ @app.get("/get_query_results")
123
+ def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
+ k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
+ k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
126
+ ) -> Dict[str, Any]:
127
+ """
128
+ Récupère les résultats d'une requête en utilisant deux modèles de recherche.
129
+
130
+ Args:
131
+ query (str): La requête utilisateur pour laquelle récupérer les résultats.
132
+ k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
133
+ k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
134
+
135
+ Returns:
136
+ Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
137
+
138
+ Notes:
139
+ - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
140
+ - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
141
+ """
142
+ try:
143
+ dict_result = database.fetch_query_results(query, k_model, k_cross)
144
+ if dict_result["status"] == "ok":
145
+ return {"status": "ok",
146
+ "results": dict_result["result"]}
147
+ else:
148
+ return dict_result
149
+ except Exception as e:
150
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
151
+
152
+ # 🔹 Exemple de modèle d'entrée utilisateur
153
+ class QueryRequest(BaseModel):
154
+ question: str
155
+
156
+ @app.post("/ask")
157
+ async def ask_question(request: QueryRequest):
158
+ try:
159
+ user_query = request.question.strip()
160
+ dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
161
+ if dict_result["status"] == "ok":
162
+ list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
163
+ if not list_chunks:
164
+ answer = ("Je ne dispose pas d’informations sur ce sujet. "
165
+ "Je peux uniquement répondre à des questions sur les articles " \
166
+ "du jeu de données.")
167
+ else:
168
+ # Construction du prompt
169
+ prompt = RAG_PROMPT_TEMPLATE.format(
170
+ context="\n".join(list_chunks),
171
+ question=user_query
172
+ )
173
+ print("*** Prompt : ", prompt)
174
+ # Génération de la réponse
175
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
176
+ outputs = model.generate(**inputs, max_new_tokens=500)
177
+ generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
178
+ answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
179
+ return {"status": "ok",
180
+ "results": dict_result["result"],
181
+ "answer": answer}
182
+ else:
183
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
184
+ {dict_result['code']} - {dict_result['message']}."
185
+ return {"status": "error", "answer": answer}
186
+ except Exception as e:
187
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
188
+ {type(e).__name__} - {str(e)}."
189
  return {"status": "error", "answer": answer}