Spaces:
Sleeping
Sleeping
Update app/main.py
Browse files- app/main.py +2 -84
app/main.py
CHANGED
|
@@ -7,23 +7,13 @@ from fastapi.responses import HTMLResponse
|
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
import torch
|
| 10 |
-
from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
|
| 11 |
-
|
| 12 |
|
| 13 |
app = FastAPI(
|
| 14 |
title="Articles API",
|
| 15 |
-
description="API pour
|
| 16 |
version="1.0"
|
| 17 |
)
|
| 18 |
|
| 19 |
-
# Chargement du modèle génératif
|
| 20 |
-
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 21 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 22 |
-
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
|
| 23 |
-
torch_dtype=torch.float16,
|
| 24 |
-
device_map="auto"
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
# CORS pour permettre l'accès depuis le navigateur
|
| 28 |
app.add_middleware(
|
| 29 |
CORSMiddleware,
|
|
@@ -152,76 +142,4 @@ def get_query_results(query: str = Query(..., description="Requête de recherche
|
|
| 152 |
except Exception as e:
|
| 153 |
return {"status": "error", "code": type(e).__name__, "message": str(e)}
|
| 154 |
|
| 155 |
-
#
|
| 156 |
-
class QueryRequest(BaseModel):
|
| 157 |
-
question: str
|
| 158 |
-
use_rerank: bool = True
|
| 159 |
-
|
| 160 |
-
@app.post("/get_answer")
|
| 161 |
-
async def ask_question(request: QueryRequest):
|
| 162 |
-
"""
|
| 163 |
-
Traite une question utilisateur en effectuant une recherche dans la base de données
|
| 164 |
-
puis en générant une réponse à l’aide du modèle de langage (RAG).
|
| 165 |
-
|
| 166 |
-
Le fonctionnement se déroule en trois étapes :
|
| 167 |
-
1. Extraction et nettoyage de la question utilisateur.
|
| 168 |
-
2. Recherche des passages pertinents dans la base de données (`fetch_query_results`).
|
| 169 |
-
3. Génération d’une réponse fondée sur les morceaux de texte récupérés (RAG).
|
| 170 |
-
|
| 171 |
-
Paramètres
|
| 172 |
-
----------
|
| 173 |
-
request : QueryRequest
|
| 174 |
-
Objet contenant la question utilisateur sous forme de chaîne de caractères.
|
| 175 |
-
|
| 176 |
-
Retour
|
| 177 |
-
------
|
| 178 |
-
dict
|
| 179 |
-
Un dictionnaire contenant :
|
| 180 |
-
- "status" : "ok" si la requête a réussi, "error" en cas d'échec.
|
| 181 |
-
- "results" : liste des chunks retournés par la base de données (présent seulement si status = "ok").
|
| 182 |
-
- "answer" : réponse générée par le modèle ou message d’erreur.
|
| 183 |
-
|
| 184 |
-
Exceptions
|
| 185 |
-
----------
|
| 186 |
-
Toute exception survenant durant l’exécution est interceptée
|
| 187 |
-
et retournée sous forme d’un message d’erreur dans la clé "answer".
|
| 188 |
-
|
| 189 |
-
Notes
|
| 190 |
-
-----
|
| 191 |
-
- Si aucun chunk pertinent n'est trouvé, la fonction renvoie un message indiquant
|
| 192 |
-
que seules les questions relatives aux articles du jeu de données peuvent être traitées.
|
| 193 |
-
- La génération de la réponse utilise un template RAG et produit jusqu’à 500 tokens.
|
| 194 |
-
"""
|
| 195 |
-
try:
|
| 196 |
-
user_query = request.question.strip()
|
| 197 |
-
use_rerank = request.use_rerank
|
| 198 |
-
dict_result = database.fetch_query_results(user_query, k_model=10,
|
| 199 |
-
k_cross=5, use_rerank=use_rerank)
|
| 200 |
-
if dict_result["status"] == "ok":
|
| 201 |
-
list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
|
| 202 |
-
if not list_chunks:
|
| 203 |
-
answer = ("Je ne dispose pas d’informations sur ce sujet. "
|
| 204 |
-
"Je peux uniquement répondre à des questions sur les articles " \
|
| 205 |
-
"du jeu de données.")
|
| 206 |
-
else:
|
| 207 |
-
# Construction du prompt
|
| 208 |
-
prompt = RAG_PROMPT_TEMPLATE.format(
|
| 209 |
-
context="\n".join(list_chunks),
|
| 210 |
-
question=user_query
|
| 211 |
-
)
|
| 212 |
-
# Génération de la réponse
|
| 213 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 214 |
-
outputs = model.generate(**inputs, max_new_tokens=500)
|
| 215 |
-
generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
|
| 216 |
-
answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
| 217 |
-
return {"status": "ok",
|
| 218 |
-
"results": dict_result["result"],
|
| 219 |
-
"answer": answer}
|
| 220 |
-
else:
|
| 221 |
-
answer = f"Une erreur est survenue lors de la récupération des informations : \
|
| 222 |
-
{dict_result['code']} - {dict_result['message']}."
|
| 223 |
-
return {"status": "error", "answer": answer}
|
| 224 |
-
except Exception as e:
|
| 225 |
-
answer = f"Une erreur est survenue lors de la récupération des informations : \
|
| 226 |
-
{type(e).__name__} - {str(e)}."
|
| 227 |
-
return {"status": "error", "answer": answer}
|
|
|
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
import torch
|
|
|
|
|
|
|
| 10 |
|
| 11 |
app = FastAPI(
|
| 12 |
title="Articles API",
|
| 13 |
+
description="API pour interroger la base articles",
|
| 14 |
version="1.0"
|
| 15 |
)
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# CORS pour permettre l'accès depuis le navigateur
|
| 18 |
app.add_middleware(
|
| 19 |
CORSMiddleware,
|
|
|
|
| 142 |
except Exception as e:
|
| 143 |
return {"status": "error", "code": type(e).__name__, "message": str(e)}
|
| 144 |
|
| 145 |
+
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|