Loren commited on
Commit
cb326b7
·
verified ·
1 Parent(s): 603f3ab

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +172 -172
app/main.py CHANGED
@@ -1,173 +1,173 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Optional, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
-
6
- from pydantic import BaseModel
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
- import torch
9
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
10
-
11
- app = FastAPI(
12
- title="Articles API",
13
- description="API pour récupérer articles et tags depuis SQLite",
14
- version="1.0"
15
- )
16
-
17
- # Chargement du modèle génératif
18
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
21
- torch_dtype=torch.float16,
22
- device_map="auto"
23
- )
24
-
25
- # CORS pour permettre l'accès depuis le navigateur
26
- app.add_middleware(
27
- CORSMiddleware,
28
- allow_origins=["*"], # autorise toutes les origines
29
- allow_credentials=True,
30
- allow_methods=["*"],
31
- allow_headers=["*"],
32
- )
33
-
34
- @app.get("/get_tags")
35
- def get_tags():
36
- """
37
- Récupère la liste de tous les tags disponibles via l'API.
38
-
39
- Returns:
40
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
41
- - Si succès :
42
- {
43
- "status": "ok",
44
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
45
- }
46
- - En cas d'erreur :
47
- {
48
- "status": "error",
49
- "code": str, # Nom de l'exception
50
- "message": str # Message de l'exception
51
- }
52
-
53
- Notes:
54
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
55
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
56
- """
57
- try:
58
- dict_result = database.fetch_tags()
59
- if dict_result["status"] == "ok":
60
- return {"status": "ok", "tags": dict_result["result"]}
61
- else:
62
- return dict_result
63
- except Exception as e:
64
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
65
-
66
- @app.get("/get_articles_with_tags")
67
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
68
- """
69
- Récupère les articles associés à une ou plusieurs tags spécifiés.
70
-
71
- Args:
72
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
73
- Doit contenir au moins un tag.
74
-
75
- Returns:
76
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
77
- - Si succès :
78
- {
79
- "status": "ok",
80
- "tags": List[str], # Tags utilisés pour filtrer
81
- "articles": List[Dict] # Liste des articles correspondants
82
- }
83
- Chaque article est un dictionnaire contenant :
84
- - 'article_id': int, ID de l'article
85
- - 'article_title': str, Titre de l'article
86
- - 'article_url': str, URL de l'article
87
- - En cas d'erreur :
88
- {
89
- "status": "error",
90
- "code": str, # Code d'erreur ou nom de l'exception
91
- "message": str # Message d'erreur
92
- }
93
-
94
- Notes:
95
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
96
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
97
- """
98
- try:
99
- dict_result = database.fetch_articles_by_tags(tags)
100
- if dict_result["status"] == "ok":
101
- return {"status": "ok",
102
- "tags": tags,
103
- "articles": dict_result["result"]}
104
- else:
105
- return dict_result
106
- except Exception as e:
107
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
108
-
109
-
110
- @app.get("/get_query_results")
111
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
112
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
113
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
114
- ) -> Dict[str, Any]:
115
- """
116
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
117
-
118
- Args:
119
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
120
- k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
121
- k_cross (int, optional): Nombre de résultats �� retourner pour le modèle croisé. Par défaut à 5.
122
-
123
- Returns:
124
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
125
-
126
- Notes:
127
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
128
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
129
- """
130
- try:
131
- dict_result = database.fetch_query_results(query, k_model, k_cross)
132
- if dict_result["status"] == "ok":
133
- return {"status": "ok",
134
- "results": dict_result["result"]}
135
- else:
136
- return dict_result
137
- except Exception as e:
138
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
139
-
140
- # 🔹 Exemple de modèle d'entrée utilisateur
141
- class QueryRequest(BaseModel):
142
- question: str
143
-
144
- @app.post("/ask")
145
- async def ask_question(request: QueryRequest):
146
- try:
147
- user_query = request.question.strip()
148
- dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
149
- if dict_result["status"] == "ok":
150
- list_chunks = [resp['chunk_text'] for resp in dict_result['results']]
151
- if not list_chunks:
152
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
153
- "Je peux uniquement répondre à des questions sur les articles " \
154
- "du jeu de données.")
155
- else:
156
- # Construction du prompt
157
- prompt = RAG_PROMPT_TEMPLATE.format(
158
- context="\n".join(list_chunks),
159
- question=user_query
160
- )
161
- # Génération de la réponse
162
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
163
- outputs = model.generate(**inputs, max_new_tokens=500)
164
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
165
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
166
- else:
167
- answer = f"Une erreur est survenue lors de la récupération des informations : \
168
- {dict_result['code']} - {dict_result['message']}."
169
- return {"answer": answer}
170
- except Exception as e:
171
- answer = f"Une erreur est survenue lors de la récupération des informations : \
172
- {type(e).__name__} - {str(e)}."
173
  return {"answer": answer}
 
1
+ from fastapi import FastAPI, Query
2
+ from typing import List, Optional, Dict, Any
3
+ from app import database
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+
6
+ from pydantic import BaseModel
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ import torch
9
+ from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
10
+
11
+ app = FastAPI(
12
+ title="Articles API",
13
+ description="API pour récupérer articles et tags depuis SQLite",
14
+ version="1.0"
15
+ )
16
+
17
+ # Chargement du modèle génératif
18
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
21
+ torch_dtype=torch.float16,
22
+ device_map="auto"
23
+ )
24
+
25
+ # CORS pour permettre l'accès depuis le navigateur
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"], # autorise toutes les origines
29
+ allow_credentials=True,
30
+ allow_methods=["*"],
31
+ allow_headers=["*"],
32
+ )
33
+
34
+ @app.get("/get_tags")
35
+ def get_tags():
36
+ """
37
+ Récupère la liste de tous les tags disponibles via l'API.
38
+
39
+ Returns:
40
+ Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
41
+ - Si succès :
42
+ {
43
+ "status": "ok",
44
+ "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
45
+ }
46
+ - En cas d'erreur :
47
+ {
48
+ "status": "error",
49
+ "code": str, # Nom de l'exception
50
+ "message": str # Message de l'exception
51
+ }
52
+
53
+ Notes:
54
+ - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
55
+ - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
56
+ """
57
+ try:
58
+ dict_result = database.fetch_tags()
59
+ if dict_result["status"] == "ok":
60
+ return {"status": "ok", "tags": dict_result["result"]}
61
+ else:
62
+ return dict_result
63
+ except Exception as e:
64
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
65
+
66
+ @app.get("/get_articles_with_tags")
67
+ def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
68
+ """
69
+ Récupère les articles associés à une ou plusieurs tags spécifiés.
70
+
71
+ Args:
72
+ tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
73
+ Doit contenir au moins un tag.
74
+
75
+ Returns:
76
+ Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
77
+ - Si succès :
78
+ {
79
+ "status": "ok",
80
+ "tags": List[str], # Tags utilisés pour filtrer
81
+ "articles": List[Dict] # Liste des articles correspondants
82
+ }
83
+ Chaque article est un dictionnaire contenant :
84
+ - 'article_id': int, ID de l'article
85
+ - 'article_title': str, Titre de l'article
86
+ - 'article_url': str, URL de l'article
87
+ - En cas d'erreur :
88
+ {
89
+ "status": "error",
90
+ "code": str, # Code d'erreur ou nom de l'exception
91
+ "message": str # Message d'erreur
92
+ }
93
+
94
+ Notes:
95
+ - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
96
+ - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
97
+ """
98
+ try:
99
+ dict_result = database.fetch_articles_by_tags(tags)
100
+ if dict_result["status"] == "ok":
101
+ return {"status": "ok",
102
+ "tags": tags,
103
+ "articles": dict_result["result"]}
104
+ else:
105
+ return dict_result
106
+ except Exception as e:
107
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
108
+
109
+
110
+ @app.get("/get_query_results")
111
+ def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
112
+ k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
113
+ k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
114
+ ) -> Dict[str, Any]:
115
+ """
116
+ Récupère les résultats d'une requête en utilisant deux modèles de recherche.
117
+
118
+ Args:
119
+ query (str): La requête utilisateur pour laquelle récupérer les résultats.
120
+ k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
121
+ k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
122
+
123
+ Returns:
124
+ Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
125
+
126
+ Notes:
127
+ - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
128
+ - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
129
+ """
130
+ try:
131
+ dict_result = database.fetch_query_results(query, k_model, k_cross)
132
+ if dict_result["status"] == "ok":
133
+ return {"status": "ok",
134
+ "results": dict_result["result"]}
135
+ else:
136
+ return dict_result
137
+ except Exception as e:
138
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
139
+
140
+ # 🔹 Exemple de modèle d'entrée utilisateur
141
+ class QueryRequest(BaseModel):
142
+ question: str
143
+
144
+ @app.post("/ask")
145
+ async def ask_question(request: QueryRequest):
146
+ try:
147
+ user_query = request.question.strip()
148
+ dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
149
+ if dict_result["status"] == "ok":
150
+ list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
151
+ if not list_chunks:
152
+ answer = ("Je ne dispose pas d’informations sur ce sujet. "
153
+ "Je peux uniquement répondre à des questions sur les articles " \
154
+ "du jeu de données.")
155
+ else:
156
+ # Construction du prompt
157
+ prompt = RAG_PROMPT_TEMPLATE.format(
158
+ context="\n".join(list_chunks),
159
+ question=user_query
160
+ )
161
+ # Génération de la réponse
162
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
163
+ outputs = model.generate(**inputs, max_new_tokens=500)
164
+ generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
165
+ answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
166
+ else:
167
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
168
+ {dict_result['code']} - {dict_result['message']}."
169
+ return {"answer": answer}
170
+ except Exception as e:
171
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
172
+ {type(e).__name__} - {str(e)}."
173
  return {"answer": answer}