Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,8 @@ import requests
|
|
| 3 |
import pandas as pd
|
| 4 |
import os
|
| 5 |
import time
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
# URL du backend Hugging Face Space
|
|
@@ -10,6 +12,14 @@ API_URL = os.getenv('API_URL')
|
|
| 10 |
API_URL_ASK = API_URL+"/get_answer"
|
| 11 |
SPACE_URL = "https://huggingface.co/api/spaces/Loren/api_search_articles"
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def get_tags():
|
| 14 |
resp = requests.get(f"{API_URL}/get_tags")
|
| 15 |
if resp.status_code != 200:
|
|
@@ -178,16 +188,9 @@ def get_answer_with_query(query, use_rerank, history):
|
|
| 178 |
if not query:
|
| 179 |
raise gr.Error("❌ Erreur : aucun query fourni.")
|
| 180 |
return None
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
}
|
| 185 |
-
headers = {
|
| 186 |
-
"Content-Type": "application/json"
|
| 187 |
-
}
|
| 188 |
-
resp = requests.post(API_URL_ASK, json=payload, headers=headers)
|
| 189 |
-
|
| 190 |
-
# Vérification du statut HTTP
|
| 191 |
if resp.status_code != 200:
|
| 192 |
raise gr.Error(f"❌ Erreur : {resp.status_code}")
|
| 193 |
return None
|
|
@@ -196,6 +199,23 @@ def get_answer_with_query(query, use_rerank, history):
|
|
| 196 |
raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")
|
| 197 |
return None
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
df = pd.DataFrame(dict_resp["results"])
|
| 200 |
# Convertir les URLs en liens HTML cliquables
|
| 201 |
df["article_url"] = df["article_url"].apply(lambda x: f'<a href="{x}" target="_blank">Ouvrir</a>')
|
|
@@ -244,6 +264,7 @@ def get_answer_with_query(query, use_rerank, history):
|
|
| 244 |
{html}
|
| 245 |
</div>
|
| 246 |
"""
|
|
|
|
| 247 |
history.append((query, dict_resp['answer']))
|
| 248 |
return "", history, styled_html
|
| 249 |
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
from templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
|
| 8 |
|
| 9 |
|
| 10 |
# URL du backend Hugging Face Space
|
|
|
|
| 12 |
API_URL_ASK = API_URL+"/get_answer"
|
| 13 |
SPACE_URL = "https://huggingface.co/api/spaces/Loren/api_search_articles"
|
| 14 |
|
| 15 |
+
# Chargement du modèle génératif
|
| 16 |
+
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 17 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 18 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
|
| 19 |
+
torch_dtype=torch.float16,
|
| 20 |
+
device_map="auto"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
def get_tags():
|
| 24 |
resp = requests.get(f"{API_URL}/get_tags")
|
| 25 |
if resp.status_code != 200:
|
|
|
|
| 188 |
if not query:
|
| 189 |
raise gr.Error("❌ Erreur : aucun query fourni.")
|
| 190 |
return None
|
| 191 |
+
params = {"query": query, "use_rerank": use_rerank}
|
| 192 |
+
|
| 193 |
+
resp = requests.get(f"{API_URL}/get_query_results", params=params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
if resp.status_code != 200:
|
| 195 |
raise gr.Error(f"❌ Erreur : {resp.status_code}")
|
| 196 |
return None
|
|
|
|
| 199 |
raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")
|
| 200 |
return None
|
| 201 |
|
| 202 |
+
list_chunks = [resp['chunk_text'] for resp in dict_resp['result']]
|
| 203 |
+
if not list_chunks:
|
| 204 |
+
answer = ("Je ne dispose pas d’informations sur ce sujet. "
|
| 205 |
+
"Je peux uniquement répondre à des questions sur les articles " \
|
| 206 |
+
"du jeu de données.")
|
| 207 |
+
else:
|
| 208 |
+
# Construction du prompt
|
| 209 |
+
prompt = RAG_PROMPT_TEMPLATE.format(
|
| 210 |
+
context="\n".join(list_chunks),
|
| 211 |
+
question=user_query
|
| 212 |
+
)
|
| 213 |
+
# Génération de la réponse
|
| 214 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 215 |
+
outputs = model.generate(**inputs, max_new_tokens=500)
|
| 216 |
+
generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
|
| 217 |
+
answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
| 218 |
+
|
| 219 |
df = pd.DataFrame(dict_resp["results"])
|
| 220 |
# Convertir les URLs en liens HTML cliquables
|
| 221 |
df["article_url"] = df["article_url"].apply(lambda x: f'<a href="{x}" target="_blank">Ouvrir</a>')
|
|
|
|
| 264 |
{html}
|
| 265 |
</div>
|
| 266 |
"""
|
| 267 |
+
|
| 268 |
history.append((query, dict_resp['answer']))
|
| 269 |
return "", history, styled_html
|
| 270 |
|