Update app.py
Browse files
app.py
CHANGED
|
@@ -24,7 +24,6 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
| 24 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 25 |
import time
|
| 26 |
|
| 27 |
-
# Ajoutez ces imports au début du fichier
|
| 28 |
from fastapi.responses import StreamingResponse
|
| 29 |
import json
|
| 30 |
import asyncio
|
|
@@ -128,11 +127,9 @@ def retrieve_relevant_context(query, embedding_model, mongo_collection, k=5):
|
|
| 128 |
|
| 129 |
docs = list(mongo_collection.find({}, {"text": 1, "embedding": 1}))
|
| 130 |
|
| 131 |
-
# Format pour affichage de debug
|
| 132 |
print(f"[DEBUG] Recherche de contexte pour: '{query}'")
|
| 133 |
print(f"[DEBUG] {len(docs)} documents trouvés dans la base de données")
|
| 134 |
|
| 135 |
-
# Si pas de documents, retourner chaîne vide
|
| 136 |
if not docs:
|
| 137 |
print("[DEBUG] Aucun document dans la collection. RAG désactivé.")
|
| 138 |
return ""
|
|
@@ -147,7 +144,6 @@ def retrieve_relevant_context(query, embedding_model, mongo_collection, k=5):
|
|
| 147 |
sim = cosine_similarity([query_embedding], [doc["embedding"]])[0][0]
|
| 148 |
similarities.append((sim, i, doc["text"]))
|
| 149 |
|
| 150 |
-
# Trier par similarité décroissante
|
| 151 |
similarities.sort(reverse=True)
|
| 152 |
|
| 153 |
# Afficher les top k documents avec leurs scores
|
|
@@ -159,7 +155,6 @@ def retrieve_relevant_context(query, embedding_model, mongo_collection, k=5):
|
|
| 159 |
top_k_docs.append(text)
|
| 160 |
print("==========================\n")
|
| 161 |
|
| 162 |
-
# Retourner le texte joint
|
| 163 |
return "\n\n".join(top_k_docs)
|
| 164 |
|
| 165 |
|
|
@@ -171,7 +166,6 @@ async def get_admin_user(request: Request):
|
|
| 171 |
return user
|
| 172 |
|
| 173 |
|
| 174 |
-
# Initialiser le modèle d'embedding (à faire une seule fois au démarrage)
|
| 175 |
try:
|
| 176 |
embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
|
| 177 |
print("✅ Modèle d'embedding médical chargé avec succès")
|
|
@@ -199,42 +193,35 @@ async def upload_pdf(
|
|
| 199 |
current_user: dict = Depends(get_admin_user)
|
| 200 |
):
|
| 201 |
try:
|
| 202 |
-
# Vérifier que le fichier est un PDF
|
| 203 |
if not file.filename.endswith('.pdf'):
|
| 204 |
raise HTTPException(status_code=400, detail="Le fichier doit être un PDF")
|
| 205 |
|
| 206 |
-
# Lire le contenu du PDF
|
| 207 |
contents = await file.read()
|
| 208 |
pdf_file = BytesIO(contents)
|
| 209 |
|
| 210 |
-
# Extraire le texte du PDF
|
| 211 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 212 |
text_content = ""
|
| 213 |
for page_num in range(len(pdf_reader.pages)):
|
| 214 |
text_content += pdf_reader.pages[page_num].extract_text() + "\n"
|
| 215 |
|
| 216 |
-
# Générer un embedding pour l'ensemble du texte si le modèle est disponible
|
| 217 |
embedding = None
|
| 218 |
if embedding_model:
|
| 219 |
try:
|
| 220 |
# Limiter la taille du texte si nécessaire
|
| 221 |
max_length = 5000
|
| 222 |
truncated_text = text_content[:max_length]
|
| 223 |
-
embedding = embedding_model.
|
| 224 |
except Exception as e:
|
| 225 |
print(f"Erreur lors de la génération de l'embedding: {str(e)}")
|
| 226 |
|
| 227 |
-
# Générer un identifiant unique pour le document
|
| 228 |
doc_id = ObjectId()
|
| 229 |
|
| 230 |
-
# Enregistrer le fichier original
|
| 231 |
pdf_path = f"files/{str(doc_id)}.pdf"
|
| 232 |
os.makedirs("files", exist_ok=True)
|
| 233 |
with open(pdf_path, "wb") as f:
|
| 234 |
pdf_file.seek(0)
|
| 235 |
f.write(contents)
|
| 236 |
|
| 237 |
-
# Créer un objet document dans MongoDB
|
| 238 |
document = {
|
| 239 |
"_id": doc_id,
|
| 240 |
"text": text_content,
|
|
@@ -266,10 +253,8 @@ async def upload_pdf(
|
|
| 266 |
@app.get("/api/admin/knowledge")
|
| 267 |
async def list_documents(current_user: dict = Depends(get_admin_user)):
|
| 268 |
try:
|
| 269 |
-
# Récupérer les documents triés par date (plus récents en premier)
|
| 270 |
documents = list(db.connaissances.find().sort("upload_date", -1))
|
| 271 |
|
| 272 |
-
# Convertir les types non sérialisables (ObjectId, datetime, etc.)
|
| 273 |
result = []
|
| 274 |
for doc in documents:
|
| 275 |
doc_safe = {
|
|
@@ -291,7 +276,6 @@ async def list_documents(current_user: dict = Depends(get_admin_user)):
|
|
| 291 |
@app.delete("/api/admin/knowledge/{document_id}")
|
| 292 |
async def delete_document(document_id: str, current_user: dict = Depends(get_admin_user)):
|
| 293 |
try:
|
| 294 |
-
# Convertir l'ID string en ObjectId
|
| 295 |
try:
|
| 296 |
doc_id = ObjectId(document_id)
|
| 297 |
except Exception:
|
|
@@ -316,7 +300,6 @@ async def delete_document(document_id: str, current_user: dict = Depends(get_adm
|
|
| 316 |
print(f"Fichier supprimé: {pdf_path}")
|
| 317 |
except Exception as e:
|
| 318 |
print(f"Erreur lors de la suppression du fichier: {str(e)}")
|
| 319 |
-
# On continue même si la suppression du fichier échoue
|
| 320 |
|
| 321 |
return {"success": True, "message": "Document supprimé avec succès"}
|
| 322 |
|
|
@@ -341,7 +324,6 @@ async def login(request: Request, response: Response):
|
|
| 341 |
user_id = str(user["_id"])
|
| 342 |
username = f"{user['prenom']} {user['nom']}"
|
| 343 |
|
| 344 |
-
# Stocker la session en base de données
|
| 345 |
db.sessions.insert_one({
|
| 346 |
"session_id": session_id,
|
| 347 |
"user_id": user_id,
|
|
@@ -349,7 +331,6 @@ async def login(request: Request, response: Response):
|
|
| 349 |
"expires_at": datetime.utcnow() + timedelta(days=7)
|
| 350 |
})
|
| 351 |
|
| 352 |
-
# Cookie configuré pour fonctionner sur HF Spaces
|
| 353 |
response.set_cookie(
|
| 354 |
key="session_id",
|
| 355 |
value=session_id,
|
|
@@ -409,7 +390,6 @@ async def get_current_user(request: Request):
|
|
| 409 |
|
| 410 |
return user
|
| 411 |
|
| 412 |
-
# Endpoint pour déconnexion
|
| 413 |
@app.post("/api/logout")
|
| 414 |
async def logout(request: Request, response: Response):
|
| 415 |
session_id = request.cookies.get("session_id")
|
|
@@ -490,18 +470,15 @@ async def chat(request: Request):
|
|
| 490 |
user_message = data.get("message", "").strip()
|
| 491 |
conversation_id = data.get("conversation_id")
|
| 492 |
|
| 493 |
-
# ② Vérification du message utilisateur
|
| 494 |
if not user_message:
|
| 495 |
raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
|
| 496 |
|
| 497 |
-
# ③ Authentification (on continue même si non authentifié)
|
| 498 |
current_user = None
|
| 499 |
try:
|
| 500 |
current_user = await get_current_user(request)
|
| 501 |
except HTTPException:
|
| 502 |
pass
|
| 503 |
|
| 504 |
-
# ④ Gestion du quota de tokens pour l'utilisateur/authenticated convo
|
| 505 |
current_tokens = 0
|
| 506 |
message_tokens = 0
|
| 507 |
if current_user and conversation_id:
|
|
@@ -521,7 +498,6 @@ async def chat(request: Request):
|
|
| 521 |
"tokens_limit": MAX_TOKENS
|
| 522 |
}, status_code=403)
|
| 523 |
|
| 524 |
-
# ⑤ Sauvegarde immédiate du message utilisateur
|
| 525 |
if conversation_id and current_user:
|
| 526 |
db.messages.insert_one({
|
| 527 |
"conversation_id": conversation_id,
|
|
@@ -531,7 +507,6 @@ async def chat(request: Request):
|
|
| 531 |
"timestamp": datetime.utcnow()
|
| 532 |
})
|
| 533 |
|
| 534 |
-
# ⑥ Détection d'une question sur l'historique
|
| 535 |
is_history_question = any(
|
| 536 |
phrase in user_message.lower()
|
| 537 |
for phrase in [
|
|
@@ -541,7 +516,6 @@ async def chat(request: Request):
|
|
| 541 |
]
|
| 542 |
)
|
| 543 |
|
| 544 |
-
# ⑦ Initialize conversation history if it doesn't exist
|
| 545 |
if conversation_id not in conversation_history:
|
| 546 |
conversation_history[conversation_id] = []
|
| 547 |
# If there's existing conversation in DB, load it to memory
|
|
@@ -556,9 +530,7 @@ async def chat(request: Request):
|
|
| 556 |
else:
|
| 557 |
conversation_history[conversation_id].append(f"Réponse : {msg['text']}")
|
| 558 |
|
| 559 |
-
# ─── Gestion spécialisée des questions d'historique ─────────
|
| 560 |
if is_history_question:
|
| 561 |
-
# Recueillir les vraies questions (pas les méta-questions sur l'historique)
|
| 562 |
actual_questions = []
|
| 563 |
|
| 564 |
if conversation_id in conversation_history:
|
|
@@ -574,23 +546,19 @@ async def chat(request: Request):
|
|
| 574 |
if not is_meta:
|
| 575 |
actual_questions.append(q_text)
|
| 576 |
|
| 577 |
-
# Cas 1: Aucune question précédente
|
| 578 |
if not actual_questions:
|
| 579 |
return JSONResponse({
|
| 580 |
"response": "Vous n'avez pas encore posé de question dans cette conversation. C'est notre premier échange."
|
| 581 |
})
|
| 582 |
|
| 583 |
-
# Détection dynamique du numéro de question demandé
|
| 584 |
question_number = None
|
| 585 |
|
| 586 |
-
# Chercher les patterns de questions spécifiques
|
| 587 |
if any(p in user_message.lower() for p in ["première question", "1ère question", "1ere question"]):
|
| 588 |
question_number = 1
|
| 589 |
elif any(p in user_message.lower() for p in ["deuxième question", "2ème question", "2eme question", "seconde question"]):
|
| 590 |
question_number = 2
|
| 591 |
else:
|
| 592 |
import re
|
| 593 |
-
# Chercher des patterns comme "3ème question", "4e question", etc.
|
| 594 |
match = re.search(r'(\d+)[eèiéê]*m*e* question', user_message.lower())
|
| 595 |
if match:
|
| 596 |
try:
|
|
@@ -598,7 +566,6 @@ async def chat(request: Request):
|
|
| 598 |
except:
|
| 599 |
pass
|
| 600 |
|
| 601 |
-
# Si on a identifié un numéro de question spécifique
|
| 602 |
if question_number is not None:
|
| 603 |
if 0 < question_number <= len(actual_questions):
|
| 604 |
suffix = "ère" if question_number == 1 else "ème"
|
|
@@ -610,7 +577,6 @@ async def chat(request: Request):
|
|
| 610 |
"response": f"Vous n'avez pas encore posé {question_number} questions dans cette conversation."
|
| 611 |
})
|
| 612 |
|
| 613 |
-
# Cas général: liste toutes les questions
|
| 614 |
else:
|
| 615 |
if len(actual_questions) == 1:
|
| 616 |
return JSONResponse({
|
|
@@ -621,30 +587,23 @@ async def chat(request: Request):
|
|
| 621 |
return JSONResponse({
|
| 622 |
"response": f"Voici les questions que vous avez posées dans cette conversation :\n\n{question_list}"
|
| 623 |
})
|
| 624 |
-
# ───────────────────────────────────────────────────────────────
|
| 625 |
|
| 626 |
-
# ⑧ RAG – récupération de contexte si ce n'est pas une question d'historique
|
| 627 |
context = None
|
| 628 |
if not is_history_question and embedding_model:
|
| 629 |
context = retrieve_relevant_context(user_message, embedding_model, db.connaissances, k=5)
|
| 630 |
-
# Store context in history
|
| 631 |
if context and conversation_id:
|
| 632 |
conversation_history[conversation_id].append(f"Contexte : {context}")
|
| 633 |
|
| 634 |
-
# Add current question to history
|
| 635 |
if conversation_id:
|
| 636 |
conversation_history[conversation_id].append(f"Question : {user_message}")
|
| 637 |
|
| 638 |
-
# ⑨ Construction du prompt système avec contexte enrichi
|
| 639 |
system_prompt = (
|
| 640 |
"Tu es un chatbot spécialisé dans la santé mentale, et plus particulièrement la schizophrénie. "
|
| 641 |
"Tu réponds de façon fiable, claire et empathique, en t'appuyant uniquement sur des sources médicales et en français. "
|
| 642 |
)
|
| 643 |
|
| 644 |
-
# Construire un contexte enrichi qui combine RAG et résumé de l'historique
|
| 645 |
enriched_context = ""
|
| 646 |
|
| 647 |
-
# Add a summary of the previous questions (the slice below keeps at most 5)
|
| 648 |
if conversation_id in conversation_history:
|
| 649 |
actual_questions = []
|
| 650 |
for msg in conversation_history[conversation_id]:
|
|
@@ -656,10 +615,9 @@ async def chat(request: Request):
|
|
| 656 |
"ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
|
| 657 |
"c'était quoi ma", "quelle était ma", "mes questions"
|
| 658 |
])
|
| 659 |
-
if not is_meta and q_text != user_message:
|
| 660 |
actual_questions.append(q_text)
|
| 661 |
|
| 662 |
-
# Add the 5 most recent questions to the context
|
| 663 |
if actual_questions:
|
| 664 |
recent_questions = actual_questions[-5:] # last 5 questions
|
| 665 |
enriched_context += "Historique récent des questions:\n"
|
|
@@ -667,13 +625,11 @@ async def chat(request: Request):
|
|
| 667 |
enriched_context += f"- Question précédente {len(recent_questions)-i}: {q}\n"
|
| 668 |
enriched_context += "\n"
|
| 669 |
|
| 670 |
-
# Ajouter le contexte RAG s'il existe
|
| 671 |
if context:
|
| 672 |
enriched_context += "Contexte médical pertinent:\n"
|
| 673 |
enriched_context += context
|
| 674 |
enriched_context += "\n\n"
|
| 675 |
|
| 676 |
-
# Compléter le prompt système
|
| 677 |
if enriched_context:
|
| 678 |
system_prompt += (
|
| 679 |
f"\n\n{enriched_context}\n\n"
|
|
@@ -686,12 +642,9 @@ async def chat(request: Request):
|
|
| 686 |
"Si tu ne sais pas répondre, indique-le clairement et suggère de consulter un professionnel de santé."
|
| 687 |
)
|
| 688 |
|
| 689 |
-
# ⑩ Construction de l'historique conversationnel pour le modèle
|
| 690 |
messages = [{"role": "system", "content": system_prompt}]
|
| 691 |
|
| 692 |
-
# Format conversation history for the LLM
|
| 693 |
if conversation_id and len(conversation_history.get(conversation_id, [])) > 0:
|
| 694 |
-
# Convert our history format to chat format (last 10 exchanges)
|
| 695 |
history = conversation_history[conversation_id]
|
| 696 |
for i in range(0, min(20, len(history)-1), 2):
|
| 697 |
if i+1 < len(history):
|
|
@@ -703,10 +656,8 @@ async def chat(request: Request):
|
|
| 703 |
assistant_text = history[i+1].replace("Réponse : ", "")
|
| 704 |
messages.append({"role": "assistant", "content": assistant_text})
|
| 705 |
|
| 706 |
-
# Add current user message
|
| 707 |
messages.append({"role": "user", "content": user_message})
|
| 708 |
|
| 709 |
-
# ⑫ Appel à l'API Hugging Face
|
| 710 |
try:
|
| 711 |
completion = hf_client.chat.completions.create(
|
| 712 |
model="mistralai/Mistral-7B-Instruct-v0.3",
|
|
@@ -725,15 +676,12 @@ async def chat(request: Request):
|
|
| 725 |
)
|
| 726 |
bot_response = fallback
|
| 727 |
|
| 728 |
-
# Add bot response to history
|
| 729 |
if conversation_id:
|
| 730 |
conversation_history[conversation_id].append(f"Réponse : {bot_response}")
|
| 731 |
|
| 732 |
-
# Keep history to a reasonable size
|
| 733 |
if len(conversation_history[conversation_id]) > 50: # 25 exchanges
|
| 734 |
conversation_history[conversation_id] = conversation_history[conversation_id][-50:]
|
| 735 |
|
| 736 |
-
# ⑬ Sauvegarde de la réponse de l'assistant + mise à jour tokens & last_message
|
| 737 |
if conversation_id and current_user:
|
| 738 |
db.messages.insert_one({
|
| 739 |
"conversation_id": conversation_id,
|
|
@@ -753,7 +701,6 @@ async def chat(request: Request):
|
|
| 753 |
}}
|
| 754 |
)
|
| 755 |
|
| 756 |
-
# ⑭ Retour de la réponse finale
|
| 757 |
return {"response": bot_response}
|
| 758 |
|
| 759 |
|
|
@@ -764,16 +711,12 @@ def simulate_token_count(text):
|
|
| 764 |
if not text:
|
| 765 |
return 0
|
| 766 |
|
| 767 |
-
# Prétraitement pour mieux gérer les cas spéciaux
|
| 768 |
text = text.replace('\n', ' \n ')
|
| 769 |
|
| 770 |
-
# Compter les caractères spéciaux et espaces
|
| 771 |
spaces_and_punct = sum(1 for c in text if c.isspace() or c in ',.;:!?()[]{}"\'`-_=+<>/@#$%^&*|\\')
|
| 772 |
|
| 773 |
-
# Compter les chiffres
|
| 774 |
digits = sum(1 for c in text if c.isdigit())
|
| 775 |
|
| 776 |
-
# Compter les mots courts et tokens spéciaux
|
| 777 |
words = text.split()
|
| 778 |
short_words = sum(1 for w in words if len(w) <= 2)
|
| 779 |
|
|
@@ -781,10 +724,8 @@ def simulate_token_count(text):
|
|
| 781 |
code_blocks = len(re.findall(r'```[\s\S]*?```', text))
|
| 782 |
urls = len(re.findall(r'https?://\S+', text))
|
| 783 |
|
| 784 |
-
# Longueur restante après ajustements
|
| 785 |
adjusted_length = len(text) - spaces_and_punct - digits - short_words
|
| 786 |
|
| 787 |
-
# Calcul final avec facteurs de pondération
|
| 788 |
token_count = (
|
| 789 |
adjusted_length / 4 +
|
| 790 |
spaces_and_punct * 0.25 +
|
|
|
|
| 24 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 25 |
import time
|
| 26 |
|
|
|
|
| 27 |
from fastapi.responses import StreamingResponse
|
| 28 |
import json
|
| 29 |
import asyncio
|
|
|
|
| 127 |
|
| 128 |
docs = list(mongo_collection.find({}, {"text": 1, "embedding": 1}))
|
| 129 |
|
|
|
|
| 130 |
print(f"[DEBUG] Recherche de contexte pour: '{query}'")
|
| 131 |
print(f"[DEBUG] {len(docs)} documents trouvés dans la base de données")
|
| 132 |
|
|
|
|
| 133 |
if not docs:
|
| 134 |
print("[DEBUG] Aucun document dans la collection. RAG désactivé.")
|
| 135 |
return ""
|
|
|
|
| 144 |
sim = cosine_similarity([query_embedding], [doc["embedding"]])[0][0]
|
| 145 |
similarities.append((sim, i, doc["text"]))
|
| 146 |
|
|
|
|
| 147 |
similarities.sort(reverse=True)
|
| 148 |
|
| 149 |
# Afficher les top k documents avec leurs scores
|
|
|
|
| 155 |
top_k_docs.append(text)
|
| 156 |
print("==========================\n")
|
| 157 |
|
|
|
|
| 158 |
return "\n\n".join(top_k_docs)
|
| 159 |
|
| 160 |
|
|
|
|
| 166 |
return user
|
| 167 |
|
| 168 |
|
|
|
|
| 169 |
try:
|
| 170 |
embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
|
| 171 |
print("✅ Modèle d'embedding médical chargé avec succès")
|
|
|
|
| 193 |
current_user: dict = Depends(get_admin_user)
|
| 194 |
):
|
| 195 |
try:
|
|
|
|
| 196 |
if not file.filename.endswith('.pdf'):
|
| 197 |
raise HTTPException(status_code=400, detail="Le fichier doit être un PDF")
|
| 198 |
|
|
|
|
| 199 |
contents = await file.read()
|
| 200 |
pdf_file = BytesIO(contents)
|
| 201 |
|
|
|
|
| 202 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 203 |
text_content = ""
|
| 204 |
for page_num in range(len(pdf_reader.pages)):
|
| 205 |
text_content += pdf_reader.pages[page_num].extract_text() + "\n"
|
| 206 |
|
|
|
|
| 207 |
embedding = None
|
| 208 |
if embedding_model:
|
| 209 |
try:
|
| 210 |
# Limiter la taille du texte si nécessaire
|
| 211 |
max_length = 5000
|
| 212 |
truncated_text = text_content[:max_length]
|
| 213 |
+
embedding = embedding_model.embed_query(truncated_text)
|
| 214 |
except Exception as e:
|
| 215 |
print(f"Erreur lors de la génération de l'embedding: {str(e)}")
|
| 216 |
|
|
|
|
| 217 |
doc_id = ObjectId()
|
| 218 |
|
|
|
|
| 219 |
pdf_path = f"files/{str(doc_id)}.pdf"
|
| 220 |
os.makedirs("files", exist_ok=True)
|
| 221 |
with open(pdf_path, "wb") as f:
|
| 222 |
pdf_file.seek(0)
|
| 223 |
f.write(contents)
|
| 224 |
|
|
|
|
| 225 |
document = {
|
| 226 |
"_id": doc_id,
|
| 227 |
"text": text_content,
|
|
|
|
| 253 |
@app.get("/api/admin/knowledge")
|
| 254 |
async def list_documents(current_user: dict = Depends(get_admin_user)):
|
| 255 |
try:
|
|
|
|
| 256 |
documents = list(db.connaissances.find().sort("upload_date", -1))
|
| 257 |
|
|
|
|
| 258 |
result = []
|
| 259 |
for doc in documents:
|
| 260 |
doc_safe = {
|
|
|
|
| 276 |
@app.delete("/api/admin/knowledge/{document_id}")
|
| 277 |
async def delete_document(document_id: str, current_user: dict = Depends(get_admin_user)):
|
| 278 |
try:
|
|
|
|
| 279 |
try:
|
| 280 |
doc_id = ObjectId(document_id)
|
| 281 |
except Exception:
|
|
|
|
| 300 |
print(f"Fichier supprimé: {pdf_path}")
|
| 301 |
except Exception as e:
|
| 302 |
print(f"Erreur lors de la suppression du fichier: {str(e)}")
|
|
|
|
| 303 |
|
| 304 |
return {"success": True, "message": "Document supprimé avec succès"}
|
| 305 |
|
|
|
|
| 324 |
user_id = str(user["_id"])
|
| 325 |
username = f"{user['prenom']} {user['nom']}"
|
| 326 |
|
|
|
|
| 327 |
db.sessions.insert_one({
|
| 328 |
"session_id": session_id,
|
| 329 |
"user_id": user_id,
|
|
|
|
| 331 |
"expires_at": datetime.utcnow() + timedelta(days=7)
|
| 332 |
})
|
| 333 |
|
|
|
|
| 334 |
response.set_cookie(
|
| 335 |
key="session_id",
|
| 336 |
value=session_id,
|
|
|
|
| 390 |
|
| 391 |
return user
|
| 392 |
|
|
|
|
| 393 |
@app.post("/api/logout")
|
| 394 |
async def logout(request: Request, response: Response):
|
| 395 |
session_id = request.cookies.get("session_id")
|
|
|
|
| 470 |
user_message = data.get("message", "").strip()
|
| 471 |
conversation_id = data.get("conversation_id")
|
| 472 |
|
|
|
|
| 473 |
if not user_message:
|
| 474 |
raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
|
| 475 |
|
|
|
|
| 476 |
current_user = None
|
| 477 |
try:
|
| 478 |
current_user = await get_current_user(request)
|
| 479 |
except HTTPException:
|
| 480 |
pass
|
| 481 |
|
|
|
|
| 482 |
current_tokens = 0
|
| 483 |
message_tokens = 0
|
| 484 |
if current_user and conversation_id:
|
|
|
|
| 498 |
"tokens_limit": MAX_TOKENS
|
| 499 |
}, status_code=403)
|
| 500 |
|
|
|
|
| 501 |
if conversation_id and current_user:
|
| 502 |
db.messages.insert_one({
|
| 503 |
"conversation_id": conversation_id,
|
|
|
|
| 507 |
"timestamp": datetime.utcnow()
|
| 508 |
})
|
| 509 |
|
|
|
|
| 510 |
is_history_question = any(
|
| 511 |
phrase in user_message.lower()
|
| 512 |
for phrase in [
|
|
|
|
| 516 |
]
|
| 517 |
)
|
| 518 |
|
|
|
|
| 519 |
if conversation_id not in conversation_history:
|
| 520 |
conversation_history[conversation_id] = []
|
| 521 |
# If there's existing conversation in DB, load it to memory
|
|
|
|
| 530 |
else:
|
| 531 |
conversation_history[conversation_id].append(f"Réponse : {msg['text']}")
|
| 532 |
|
|
|
|
| 533 |
if is_history_question:
|
|
|
|
| 534 |
actual_questions = []
|
| 535 |
|
| 536 |
if conversation_id in conversation_history:
|
|
|
|
| 546 |
if not is_meta:
|
| 547 |
actual_questions.append(q_text)
|
| 548 |
|
|
|
|
| 549 |
if not actual_questions:
|
| 550 |
return JSONResponse({
|
| 551 |
"response": "Vous n'avez pas encore posé de question dans cette conversation. C'est notre premier échange."
|
| 552 |
})
|
| 553 |
|
|
|
|
| 554 |
question_number = None
|
| 555 |
|
|
|
|
| 556 |
if any(p in user_message.lower() for p in ["première question", "1ère question", "1ere question"]):
|
| 557 |
question_number = 1
|
| 558 |
elif any(p in user_message.lower() for p in ["deuxième question", "2ème question", "2eme question", "seconde question"]):
|
| 559 |
question_number = 2
|
| 560 |
else:
|
| 561 |
import re
|
|
|
|
| 562 |
match = re.search(r'(\d+)[eèiéê]*m*e* question', user_message.lower())
|
| 563 |
if match:
|
| 564 |
try:
|
|
|
|
| 566 |
except:
|
| 567 |
pass
|
| 568 |
|
|
|
|
| 569 |
if question_number is not None:
|
| 570 |
if 0 < question_number <= len(actual_questions):
|
| 571 |
suffix = "ère" if question_number == 1 else "ème"
|
|
|
|
| 577 |
"response": f"Vous n'avez pas encore posé {question_number} questions dans cette conversation."
|
| 578 |
})
|
| 579 |
|
|
|
|
| 580 |
else:
|
| 581 |
if len(actual_questions) == 1:
|
| 582 |
return JSONResponse({
|
|
|
|
| 587 |
return JSONResponse({
|
| 588 |
"response": f"Voici les questions que vous avez posées dans cette conversation :\n\n{question_list}"
|
| 589 |
})
|
|
|
|
| 590 |
|
|
|
|
| 591 |
context = None
|
| 592 |
if not is_history_question and embedding_model:
|
| 593 |
context = retrieve_relevant_context(user_message, embedding_model, db.connaissances, k=5)
|
|
|
|
| 594 |
if context and conversation_id:
|
| 595 |
conversation_history[conversation_id].append(f"Contexte : {context}")
|
| 596 |
|
|
|
|
| 597 |
if conversation_id:
|
| 598 |
conversation_history[conversation_id].append(f"Question : {user_message}")
|
| 599 |
|
|
|
|
| 600 |
system_prompt = (
|
| 601 |
"Tu es un chatbot spécialisé dans la santé mentale, et plus particulièrement la schizophrénie. "
|
| 602 |
"Tu réponds de façon fiable, claire et empathique, en t'appuyant uniquement sur des sources médicales et en français. "
|
| 603 |
)
|
| 604 |
|
|
|
|
| 605 |
enriched_context = ""
|
| 606 |
|
|
|
|
| 607 |
if conversation_id in conversation_history:
|
| 608 |
actual_questions = []
|
| 609 |
for msg in conversation_history[conversation_id]:
|
|
|
|
| 615 |
"ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
|
| 616 |
"c'était quoi ma", "quelle était ma", "mes questions"
|
| 617 |
])
|
| 618 |
+
if not is_meta and q_text != user_message:
|
| 619 |
actual_questions.append(q_text)
|
| 620 |
|
|
|
|
| 621 |
if actual_questions:
|
| 622 |
recent_questions = actual_questions[-5:] # last 5 questions
|
| 623 |
enriched_context += "Historique récent des questions:\n"
|
|
|
|
| 625 |
enriched_context += f"- Question précédente {len(recent_questions)-i}: {q}\n"
|
| 626 |
enriched_context += "\n"
|
| 627 |
|
|
|
|
| 628 |
if context:
|
| 629 |
enriched_context += "Contexte médical pertinent:\n"
|
| 630 |
enriched_context += context
|
| 631 |
enriched_context += "\n\n"
|
| 632 |
|
|
|
|
| 633 |
if enriched_context:
|
| 634 |
system_prompt += (
|
| 635 |
f"\n\n{enriched_context}\n\n"
|
|
|
|
| 642 |
"Si tu ne sais pas répondre, indique-le clairement et suggère de consulter un professionnel de santé."
|
| 643 |
)
|
| 644 |
|
|
|
|
| 645 |
messages = [{"role": "system", "content": system_prompt}]
|
| 646 |
|
|
|
|
| 647 |
if conversation_id and len(conversation_history.get(conversation_id, [])) > 0:
|
|
|
|
| 648 |
history = conversation_history[conversation_id]
|
| 649 |
for i in range(0, min(20, len(history)-1), 2):
|
| 650 |
if i+1 < len(history):
|
|
|
|
| 656 |
assistant_text = history[i+1].replace("Réponse : ", "")
|
| 657 |
messages.append({"role": "assistant", "content": assistant_text})
|
| 658 |
|
|
|
|
| 659 |
messages.append({"role": "user", "content": user_message})
|
| 660 |
|
|
|
|
| 661 |
try:
|
| 662 |
completion = hf_client.chat.completions.create(
|
| 663 |
model="mistralai/Mistral-7B-Instruct-v0.3",
|
|
|
|
| 676 |
)
|
| 677 |
bot_response = fallback
|
| 678 |
|
|
|
|
| 679 |
if conversation_id:
|
| 680 |
conversation_history[conversation_id].append(f"Réponse : {bot_response}")
|
| 681 |
|
|
|
|
| 682 |
if len(conversation_history[conversation_id]) > 50: # 25 exchanges
|
| 683 |
conversation_history[conversation_id] = conversation_history[conversation_id][-50:]
|
| 684 |
|
|
|
|
| 685 |
if conversation_id and current_user:
|
| 686 |
db.messages.insert_one({
|
| 687 |
"conversation_id": conversation_id,
|
|
|
|
| 701 |
}}
|
| 702 |
)
|
| 703 |
|
|
|
|
| 704 |
return {"response": bot_response}
|
| 705 |
|
| 706 |
|
|
|
|
| 711 |
if not text:
|
| 712 |
return 0
|
| 713 |
|
|
|
|
| 714 |
text = text.replace('\n', ' \n ')
|
| 715 |
|
|
|
|
| 716 |
spaces_and_punct = sum(1 for c in text if c.isspace() or c in ',.;:!?()[]{}"\'`-_=+<>/@#$%^&*|\\')
|
| 717 |
|
|
|
|
| 718 |
digits = sum(1 for c in text if c.isdigit())
|
| 719 |
|
|
|
|
| 720 |
words = text.split()
|
| 721 |
short_words = sum(1 for w in words if len(w) <= 2)
|
| 722 |
|
|
|
|
| 724 |
code_blocks = len(re.findall(r'```[\s\S]*?```', text))
|
| 725 |
urls = len(re.findall(r'https?://\S+', text))
|
| 726 |
|
|
|
|
| 727 |
adjusted_length = len(text) - spaces_and_punct - digits - short_words
|
| 728 |
|
|
|
|
| 729 |
token_count = (
|
| 730 |
adjusted_length / 4 +
|
| 731 |
spaces_and_punct * 0.25 +
|