Update app.py
app.py CHANGED
@@ -20,6 +20,21 @@ import PyPDF2
 from io import BytesIO
 import uuid
 
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from sklearn.metrics.pairwise import cosine_similarity
+import time
+
+# Add these imports at the top of the file
+from fastapi.responses import StreamingResponse
+import json
+import asyncio
+
+
+from langchain_community.document_loaders import PyPDFDirectoryLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
 SECRET_KEY = secrets.token_hex(32)
 
 HOST = os.environ.get("API_URL", "0.0.0.0")
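Three of the new imports (StreamingResponse, json, asyncio) are never used in the hunks below, and HuggingFaceEmbeddings is imported twice. Trimmed to what this diff actually exercises, the block could shrink to the sketch below; this is a suggestion, not what the commit contains:

    import time
    from sklearn.metrics.pairwise import cosine_similarity
    from langchain_community.embeddings import HuggingFaceEmbeddings
    # used only by the disabled indexing block further down:
    from langchain_community.document_loaders import PyPDFDirectoryLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter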
@@ -38,6 +53,9 @@ db_name = os.environ.get("DB_NAME", "chatmed_schizo")
 mongo_client = MongoClient(mongo_uri)
 db = mongo_client[db_name]
 
+SAVE_FOLDER = "files"
+COLLECTION_NAME = "connaissances"
+os.makedirs(SAVE_FOLDER, exist_ok=True)
 
 
 app = FastAPI()
@@ -56,6 +74,95 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+def download_pdf(url, save_path, retries=2, delay=3):
+    for attempt in range(retries):
+        try:
+            req = Request(url, headers={'User-Agent': 'Mozilla/5.0'})
+            with urlopen(req) as response, open(save_path, 'wb') as f:
+                f.write(response.read())
+            print(f"Téléchargé : {save_path}")
+            return
+        except (HTTPError, URLError) as e:
+            print(f"Erreur ({e}) pour {url}, tentative {attempt+1}/{retries}")
+            time.sleep(delay)
+    print(f"Échec du téléchargement : {url}")
+
+'''
+Automatic PDF loading is disabled. The database uses the existing embeddings.
+for url in PDF_URLS:
+    file_name = url.split("/")[-1]
+    file_path = os.path.join(SAVE_FOLDER, file_name)
+    if not os.path.exists(file_path):
+        download_pdf(url, file_path)
+
+loader = PyPDFDirectoryLoader(SAVE_FOLDER)
+docs = loader.load()
+
+splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+chunks = splitter.split_documents(docs)
+print(f"{len(chunks)} morceaux extraits.")
+
+embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
+
+client = MongoClient(MONGO_URI)
+collection = client[DB_NAME][COLLECTION_NAME]
+
+collection.delete_many({})
+
+for chunk in chunks:
+    text = chunk.page_content
+    embedding = embedding_model.embed_query(text)
+    collection.insert_one({
+        "text": text,
+        "embedding": embedding
+    })
+
+print("Tous les embeddings ont été insérés dans la base MongoDB.")
+'''
+
+
+
+
+def retrieve_relevant_context(query, embedding_model, mongo_collection, k=5):
+    query_embedding = embedding_model.embed_query(query)
+
+    docs = list(mongo_collection.find({}, {"text": 1, "embedding": 1}))
+
+    # Debug output
+    print(f"[DEBUG] Recherche de contexte pour: '{query}'")
+    print(f"[DEBUG] {len(docs)} documents trouvés dans la base de données")
+
+    # If there are no documents, return an empty string
+    if not docs:
+        print("[DEBUG] Aucun document dans la collection. RAG désactivé.")
+        return ""
+
+    # Compute the similarities
+    similarities = []
+    for i, doc in enumerate(docs):
+        if "embedding" not in doc or not doc["embedding"]:
+            print(f"[DEBUG] Document {i} sans embedding")
+            continue
+
+        sim = cosine_similarity([query_embedding], [doc["embedding"]])[0][0]
+        similarities.append((sim, i, doc["text"]))
+
+    # Sort by decreasing similarity
+    similarities.sort(reverse=True)
+
+    # Show the top k documents with their scores
+    print("\n=== CONTEXTE SÉLECTIONNÉ ===")
+    top_k_docs = []
+    for i, (score, idx, text) in enumerate(similarities[:k]):
+        doc_preview = text[:100] + "..." if len(text) > 100 else text
+        print(f"Document #{i+1} (score: {score:.4f}): {doc_preview}")
+        top_k_docs.append(text)
+    print("==========================\n")
+
+    # Return the joined text
+    return "\n\n".join(top_k_docs)
+
+
 
 async def get_admin_user(request: Request):
     user = await get_current_user(request)
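download_pdf relies on Request, urlopen, HTTPError and URLError, none of which are imported in this diff, so they presumably exist earlier in app.py. If not, the usual stdlib imports would be the ones below; note that urllib's Request would then collide with the fastapi.Request already used in the route signatures, so an alias may be needed.

    from urllib.request import Request, urlopen   # assumed import, not part of this commit
    from urllib.error import HTTPError, URLError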
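A minimal sketch of how retrieve_relevant_context ranks documents, runnable inside this module with stand-ins for the embedder and the Mongo collection (FakeEmbedder and FakeCollection are illustrative names, not part of the commit):

    class FakeEmbedder:
        # maps text to a 2-d vector; a stand-in for HuggingFaceEmbeddings
        def embed_query(self, text):
            return [1.0, 0.0] if "symptôme" in text else [0.0, 1.0]

    class FakeCollection:
        def __init__(self, docs):
            self._docs = docs
        def find(self, _filter, _projection):
            return list(self._docs)

    docs = [
        {"text": "Les symptômes positifs incluent les hallucinations.", "embedding": [1.0, 0.0]},
        {"text": "La page d'accueil du site.", "embedding": [0.0, 1.0]},
    ]
    ctx = retrieve_relevant_context("Quels sont les symptômes ?", FakeEmbedder(), FakeCollection(docs), k=1)
    # expected: the symptom sentence, since its embedding is colinear with the query's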
@@ -66,10 +173,24 @@ async def get_admin_user(request: Request):
 
 # Initialize the embedding model (do this once, at startup)
 try:
-
+    embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
+    print("✅ Modèle d'embedding médical chargé avec succès")
+
 except Exception as e:
     print(f"Erreur lors du chargement du modèle d'embedding: {str(e)}")
-
+    embedding_model = None
+
+doc_count = db.connaissances.count_documents({})
+print(f"\n[DIAGNOSTIC] Collection 'connaissances': {doc_count} documents trouvés")
+if doc_count == 0:
+    print("[AVERTISSEMENT] La collection est vide. Le système RAG ne fonctionnera pas!")
+    print("[AVERTISSEMENT] Veuillez charger des documents via l'API admin ou exécuter le script d'initialisation.")
+else:
+    sample_doc = db.connaissances.find_one({})
+    has_embeddings = "embedding" in sample_doc and sample_doc["embedding"] is not None
+    print(f"[DIAGNOSTIC] Les documents ont des embeddings: {'✅ Oui' if has_embeddings else '❌ Non'}")
+    if not has_embeddings:
+        print("[AVERTISSEMENT] Les documents n'ont pas d'embeddings valides!")
 @app.post("/api/admin/knowledge/upload")
 async def upload_pdf(
     file: UploadFile = File(...),
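Assuming the model loads, a quick smoke test of the embedder (illustrative only, not part of the commit):

    if embedding_model is not None:
        vec = embedding_model.embed_query("schizophrénie")
        print(len(vec))  # embed_query returns a plain list of floats; its length is the model's embedding dimension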
@@ -358,13 +479,231 @@ async def invert(text: str):
 HF_TOKEN = os.getenv('REACT_APP_HF_TOKEN')
 if not HF_TOKEN:
     raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
-
+conversation_history = {}
 hf_client = InferenceClient(token=HF_TOKEN)
-
 @app.post("/api/chat")
 async def chat(request: Request):
+    global conversation_history
+
+    # ① Read the JSON body and extract the fields
     data = await request.json()
     user_message = data.get("message", "").strip()
+    conversation_id = data.get("conversation_id")
+
+    # ② Validate the user message
     if not user_message:
         raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
 
+    # ③ Authentication (continue even if unauthenticated)
+    current_user = None
+    try:
+        current_user = await get_current_user(request)
+    except HTTPException:
+        pass
+
+    # ④ Token-quota handling for the authenticated user/conversation
+    current_tokens = 0
+    message_tokens = 0
+    if current_user and conversation_id:
+        conv = db.conversations.find_one({
+            "_id": ObjectId(conversation_id),
+            "user_id": str(current_user["_id"])
+        })
+        if conv:
+            current_tokens = conv.get("token_count", 0)
+            message_tokens = int(len(user_message.split()) * 1.3)
+            MAX_TOKENS = 2000
+            if current_tokens + message_tokens > MAX_TOKENS:
+                return JSONResponse({
+                    "error": "token_limit_exceeded",
+                    "message": "Cette conversation a atteint sa limite de taille. Veuillez en créer une nouvelle.",
+                    "tokens_used": current_tokens,
+                    "tokens_limit": MAX_TOKENS
+                }, status_code=403)
+
+    # ⑤ Save the user message immediately
+    if conversation_id and current_user:
+        db.messages.insert_one({
+            "conversation_id": conversation_id,
+            "user_id": str(current_user["_id"]),
+            "sender": "user",
+            "text": user_message,
+            "timestamp": datetime.utcnow()
+        })
+
+    # ⑥ Detect a question about the conversation history
+    is_history_question = any(
+        phrase in user_message.lower()
+        for phrase in [
+            "ma première question", "ma précédente question", "ma dernière question",
+            "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
+            "c'était quoi ma", "quelle était ma", "mes questions"
+        ]
+    )
+
+    # ⑦ Initialize conversation history if it doesn't exist
+    if conversation_id not in conversation_history:
+        conversation_history[conversation_id] = []
+        # If there's an existing conversation in the DB, load it into memory
+        if current_user and conversation_id:
+            previous_messages = list(db.messages.find(
+                {"conversation_id": conversation_id}
+            ).sort("timestamp", 1))
+
+            for msg in previous_messages:
+                if msg["sender"] == "user":
+                    conversation_history[conversation_id].append(f"Question : {msg['text']}")
+                else:
+                    conversation_history[conversation_id].append(f"Réponse : {msg['text']}")
+
+    # ─── Specialized handling of history questions ─────────
+    if is_history_question:
+        # Collect the real questions (not the meta-questions about the history)
+        actual_questions = []
+
+        if conversation_id in conversation_history:
+            for msg in conversation_history[conversation_id]:
+                if msg.startswith("Question : "):
+                    q_text = msg.replace("Question : ", "")
+                    # Ignore meta-questions that are themselves about the history
+                    is_meta = any(phrase in q_text.lower() for phrase in [
+                        "ma première question", "ma précédente question", "ma dernière question",
+                        "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
+                        "c'était quoi ma", "quelle était ma", "mes questions"
+                    ])
+                    if not is_meta:
+                        actual_questions.append(q_text)
+
+        # Case 1: no previous question
+        if not actual_questions:
+            return JSONResponse({
+                "response": "Vous n'avez pas encore posé de question dans cette conversation. C'est notre premier échange."
+            })
+
+        # Dynamically detect which question number is being asked
+        question_number = None
+
+        # Look for patterns naming a specific question
+        if any(p in user_message.lower() for p in ["première question", "1ère question", "1ere question"]):
+            question_number = 1
+        elif any(p in user_message.lower() for p in ["deuxième question", "2ème question", "2eme question", "seconde question"]):
+            question_number = 2
+        else:
+            import re
+            # Look for patterns like "3ème question", "4e question", etc.
+            match = re.search(r'(\d+)[eèiéê]*m*e* question', user_message.lower())
+            if match:
+                try:
+                    question_number = int(match.group(1))
+                except:
+                    pass
+
+        # If a specific question number was identified
+        if question_number is not None:
+            if 0 < question_number <= len(actual_questions):
+                suffix = "ère" if question_number == 1 else "ème"
+                return JSONResponse({
+                    "response": f"Votre {question_number}{suffix} question était : \"{actual_questions[question_number-1]}\""
+                })
+            else:
+                return JSONResponse({
+                    "response": f"Vous n'avez pas encore posé {question_number} questions dans cette conversation."
+                })
+
+        # General case: list all the questions
+        else:
+            if len(actual_questions) == 1:
+                return JSONResponse({
+                    "response": f"Vous avez posé une seule question jusqu'à présent : \"{actual_questions[0]}\""
+                })
+            else:
+                question_list = "\n".join([f"{i+1}. {q}" for i, q in enumerate(actual_questions)])
+                return JSONResponse({
+                    "response": f"Voici les questions que vous avez posées dans cette conversation :\n\n{question_list}"
+                })
+    # ───────────────────────────────────────────────────────────────
+
+    # ⑧ RAG – retrieve context when this is not a history question
+    context = None
+    if not is_history_question and embedding_model:
+        context = retrieve_relevant_context(user_message, embedding_model, db.connaissances, k=5)
+        # Store context in history
+        if context and conversation_id:
+            conversation_history[conversation_id].append(f"Contexte : {context}")
+
+    # Add current question to history
+    if conversation_id:
+        conversation_history[conversation_id].append(f"Question : {user_message}")
+
+    # ⑨ Build the system prompt with enriched context
+    system_prompt = (
+        "Tu es un chatbot spécialisé dans la santé mentale, et plus particulièrement la schizophrénie. "
+        "Tu réponds de façon fiable, claire et empathique, en t'appuyant uniquement sur des sources médicales et en français. "
+    )
+
+    # Build an enriched context combining the RAG results and a summary of the history
+    enriched_context = ""
+
+    # Add a summary of previous questions (at most 5)
+    if conversation_id in conversation_history:
+        actual_questions = []
+        for msg in conversation_history[conversation_id]:
+            if msg.startswith("Question : "):
+                q_text = msg.replace("Question : ", "")
+                # Ignore meta-questions
+                is_meta = any(phrase in q_text.lower() for phrase in [
+                    "ma première question", "ma précédente question", "ma dernière question",
+                    "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
+                    "c'était quoi ma", "quelle était ma", "mes questions"
+                ])
+                if not is_meta and q_text != user_message:  # don't include the current question
+                    actual_questions.append(q_text)
+
+        # Add the last 5 questions to the context
+        if actual_questions:
+            recent_questions = actual_questions[-5:]  # last 5 questions
+            enriched_context += "Historique récent des questions:\n"
+            for i, q in enumerate(recent_questions):
+                enriched_context += f"- Question précédente {len(recent_questions)-i}: {q}\n"
+            enriched_context += "\n"
+
+    # Add the RAG context if any
+    if context:
+        enriched_context += "Contexte médical pertinent:\n"
+        enriched_context += context
+        enriched_context += "\n\n"
+
+    # Finish the system prompt
+    if enriched_context:
+        system_prompt += (
+            f"\n\n{enriched_context}\n\n"
+            "Utilise ces informations pour répondre de manière plus précise et contextuelle. "
+            "Ne pas inventer d'informations. Si tu ne sais pas, redirige vers un professionnel de santé."
+        )
+    else:
+        system_prompt += (
+            "Tu dois répondre uniquement à partir de connaissances médicales factuelles. "
+            "Si tu ne sais pas répondre, indique-le clairement et suggère de consulter un professionnel de santé."
+        )
+
+    # ⑩ Build the conversational history for the model
+    messages = [{"role": "system", "content": system_prompt}]
+
+    # Format conversation history for the LLM
+    if conversation_id and len(conversation_history.get(conversation_id, [])) > 0:
+        # Convert our history format to chat format (last 10 exchanges)
+        history = conversation_history[conversation_id]
+        for i in range(0, min(20, len(history)-1), 2):
+            if i+1 < len(history):
+                if history[i].startswith("Question :"):
+                    user_text = history[i].replace("Question : ", "")
+                    messages.append({"role": "user", "content": user_text})
+
+                if history[i+1].startswith("Réponse :"):
+                    assistant_text = history[i+1].replace("Réponse : ", "")
+                    messages.append({"role": "assistant", "content": assistant_text})
+
+    # Add current user message
+    messages.append({"role": "user", "content": user_message})
+
+    # ⑫ Call the Hugging Face API
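The quota check above estimates usage at roughly 1.3 tokens per whitespace-separated word, for example:

    msg = "Quels sont les premiers signes de la schizophrénie ?"
    print(int(len(msg.split()) * 1.3))  # 9 words -> int(9 * 1.3) = int(11.7) -> 11 estimated tokens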
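The ordinal pattern r'(\d+)[eèiéê]*m*e* question' is deliberately loose about the suffix; a quick illustration of what it accepts:

    import re
    for q in ["c'était quoi ma 3ème question", "quelle était ma 4e question", "ma 10eme question"]:
        m = re.search(r'(\d+)[eèiéê]*m*e* question', q.lower())
        print(m.group(1) if m else None)  # prints 3, then 4, then 10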
@@ -371,20 +710,91 @@ async def chat(request: Request):
     try:
-
+        completion = hf_client.chat.completions.create(
+            model="mistralai/Mistral-7B-Instruct-v0.3",
+            messages=messages,
+            max_tokens=400,
+            temperature=0.7,
+            timeout=15,
+        )
+        bot_response = completion.choices[0].message["content"].strip()
+    except Exception:
+        fallback = hf_client.text_generation(
             model="mistralai/Mistral-7B-Instruct-v0.3",
-            prompt=f"<s>[INST]
+            prompt=f"<s>[INST] {system_prompt}\n\nQuestion: {user_message} [/INST]",
             max_new_tokens=512,
             temperature=0.7
         )
+        bot_response = fallback
+
+        # Add bot response to history
+        if conversation_id:
+            conversation_history[conversation_id].append(f"Réponse : {bot_response}")
 
-
-
-
-        import traceback
-        print(f"Erreur détaillée: {traceback.format_exc()}")
-        raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {str(e)}")
+    # Keep history to a reasonable size
+    if len(conversation_history[conversation_id]) > 50:  # 25 exchanges
+        conversation_history[conversation_id] = conversation_history[conversation_id][-50:]
 
+    # ⑬ Save the assistant reply + update token count & last_message
+    if conversation_id and current_user:
+        db.messages.insert_one({
+            "conversation_id": conversation_id,
+            "user_id": str(current_user["_id"]),
+            "sender": "assistant",
+            "text": bot_response,
+            "timestamp": datetime.utcnow()
+        })
+        response_tokens = int(len(bot_response.split()) * 1.3)
+        total_tokens = current_tokens + message_tokens + response_tokens
+        db.conversations.update_one(
+            {"_id": ObjectId(conversation_id)},
+            {"$set": {
+                "last_message": bot_response,
+                "updated_at": datetime.utcnow(),
+                "token_count": total_tokens
+            }}
+        )
+
+    # ⑭ Return the final response
+    return {"response": bot_response}
 
 
+def simulate_token_count(text):
+    """
+    Simulates token counting without calling an external API.
+    """
+    if not text:
+        return 0
+
+    # Preprocessing to better handle special cases
+    text = text.replace('\n', ' \n ')
+
+    # Count special characters and whitespace
+    spaces_and_punct = sum(1 for c in text if c.isspace() or c in ',.;:!?()[]{}"\'`-_=+<>/@#$%^&*|\\')
+
+    # Count digits
+    digits = sum(1 for c in text if c.isdigit())
+
+    # Count short words and special tokens
+    words = text.split()
+    short_words = sum(1 for w in words if len(w) <= 2)
+
+    # URLs and code consume more tokens
+    code_blocks = len(re.findall(r'```[\s\S]*?```', text))
+    urls = len(re.findall(r'https?://\S+', text))
+
+    # Remaining length after the adjustments
+    adjusted_length = len(text) - spaces_and_punct - digits - short_words
+
+    # Final computation with weighting factors
+    token_count = (
+        adjusted_length / 4 +
+        spaces_and_punct * 0.25 +
+        digits * 0.5 +
+        short_words * 0.5 +
+        code_blocks * 5 +
+        urls * 4
+    )
+
+    return int(token_count * 1.1) + 1
 @app.get("/data")
 async def get_data():
     data = {"data": np.random.rand(100).tolist()}
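simulate_token_count uses re at module scope, while the only import re in this diff is local to chat(); it therefore assumes a top-level import re elsewhere in app.py. A quick sanity check, under that assumption:

    print(simulate_token_count(""))                               # 0 for empty input
    print(simulate_token_count("Bonjour, comment allez-vous ?"))  # a small positive estimate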
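For reference, the primary path in the try block above maps onto huggingface_hub's OpenAI-compatible shim; a standalone call looks roughly like this (requires network access and a real token; the placeholder is an assumption):

    from huggingface_hub import InferenceClient

    client = InferenceClient(token="hf_xxx")  # placeholder token
    out = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.3",
        messages=[{"role": "user", "content": "Bonjour"}],
        max_tokens=32,
    )
    print(out.choices[0].message.content)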
@@ -508,6 +918,7 @@ async def delete_conversation(conversation_id: str, current_user: dict = Depends
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
 
+
 app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
 if __name__ == "__main__":
@@ -515,12 +926,12 @@
 
     print(args)
     uvicorn.run(
-        "
+        "app2:app",
         host=args.host,
         port=args.port,
         reload=args.reload,
+
         ssl_certfile=args.ssl_certfile,
         ssl_keyfile=args.ssl_keyfile,
     )
 
-
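One thing to double-check: the new import string "app2:app" tells uvicorn to load a module named app2.py, while this commit edits app.py. If the file really is app.py, the conventional call would be the sketch below (an assumption about the deployment, not part of the commit):

    uvicorn.run(
        "app:app",  # "module:attribute" must match the actual filename
        host=args.host,
        port=args.port,
        reload=args.reload,
        ssl_certfile=args.ssl_certfile,
        ssl_keyfile=args.ssl_keyfile,
    )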