# Legal-Chatbot / main.py
# Author: MISSAOUI — commit 0c728a5 ("Update main.py", verified)
# ======================================
# main.py – FastAPI pour chatbot douanier 🇹🇳
# ======================================
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict
import uuid
import uvicorn
# Import des modules internes
from core.retrieval import find_relevant_articles
from core.llm import generate_synthesized_llm_response_with_sources
from core.memory import (
start_new_session,
add_message_to_session,
rename_session,
update_session_title,
get_messages_for_session,
load_all_sessions
)
# ======================================================
# CONFIGURATION DE L'APPLICATION
# ======================================================
app = FastAPI(title="Chatbot Douane API 🇹🇳")

# Cross-origin setup: every origin, HTTP method and request header is
# accepted, so a browser frontend served from another host can call the API.
# NOTE(review): wildcard origins are very permissive — consider restricting
# allow_origins to the known frontend host(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ----------------------------
# Pydantic models
# ----------------------------
class MessageData(BaseModel):
    # A single chat message as supplied by the client.
    # (Comments rather than a docstring: pydantic copies class docstrings into
    # the generated JSON/OpenAPI schema description.)

    # Kind/author of the message; defaults to "human".
    # NOTE(review): no set of allowed values is enforced here — confirm which
    # values core.memory expects.
    type: Optional[str] = "human"
    # Text of the message.
    content: str
class AddMessageRequest(BaseModel):
    # Request body for POST /sessions/add_message.
    session_id: str  # session the message is appended to
    message: MessageData  # the message itself
class RenameSessionRequest(BaseModel):
    # Request body for POST /sessions/rename.
    session_id: str  # session to rename
    new_title: str  # title to set
class NewSessionResponse(BaseModel):
    # Response model for POST /sessions/new.
    session_id: str  # identifier of the newly created session
class ChatRequest(BaseModel):
    # Request body for POST /chat.
    question: str  # user question sent to retrieval and the LLM
    # Existing session to attach the exchange to; when omitted (or empty),
    # /chat generates a new "session_<uuid4>" identifier.
    session_id: Optional[str] = None
    # Optional extra web-search context forwarded to the LLM helper
    # (replaced by {} when absent). NOTE(review): expected key schema is
    # defined by core.llm — not visible here.
    web_results: Optional[Dict] = None
class ChatResponse(BaseModel):
    # Response model for POST /chat.
    session_id: str  # session the answer belongs to (possibly newly generated)
    answer: str  # LLM-generated answer text
    # One dict per retrieved article, as built by /chat:
    # {"article_num": ..., "similarity": <score rounded to 3 decimals>}
    articles_found: List[Dict]
# ----------------------------
# Session endpoints
# ----------------------------
@app.get("/sessions")
def get_all_sessions():
    """List the stored chat sessions.

    Returns exactly what ``load_all_sessions`` yields (presumably each
    session's title and creation date — see core.memory).
    """
    return load_all_sessions()
@app.post("/sessions/new", response_model=NewSessionResponse)
def create_session():
    """Open a new chat session and return its identifier.

    ``start_new_session`` receives a fresh, request-local state dict that is
    discarded afterwards; only the returned ID is sent to the client.
    """
    new_id = start_new_session({"sessions": {}})
    return {"session_id": new_id}
@app.post("/sessions/add_message")
def add_message(req: AddMessageRequest):
    """Append a message to a session.

    The message model is converted to a plain dict and passed to
    ``add_message_to_session``. Any error raised while doing so is reported
    as HTTP 400 with the exception text as detail. The original exception is
    chained as the cause so server-side tracebacks keep it.
    """
    try:
        message = req.message
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated
        # the old name; use the new one when present, fall back for v1.
        to_dict = getattr(message, "model_dump", None) or message.dict
        add_message_to_session(req.session_id, to_dict())
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}
@app.post("/sessions/rename")
def rename(req: RenameSessionRequest):
    """Manually set the title of a session.

    Any error raised by ``rename_session`` is reported as HTTP 400 with the
    exception text as detail. The original exception is chained as the cause.
    """
    try:
        rename_session(req.session_id, req.new_title)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}
@app.post("/sessions/update_title/{session_id}")
def update_title(session_id: str):
    """Recompute a session's title automatically via ``update_session_title``.

    NOTE(review): the title is presumably derived from the session's first
    message — that logic lives in core.memory and is not visible here.
    Any error is reported as HTTP 400 with the exception text as detail. The
    original exception is chained as the cause.
    """
    try:
        update_session_title(session_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}
@app.get("/sessions/{session_id}/messages")
def get_messages(session_id: str):
    """Return every message stored for ``session_id``.

    The value from ``get_messages_for_session`` is returned as-is. Any error
    is reported as HTTP 400 with the exception text as detail. The original
    exception is chained as the cause.
    """
    try:
        messages = get_messages_for_session(session_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return messages
# ----------------------------
# Chatbot endpoint
# ----------------------------
@app.post("/chat", response_model=ChatResponse)
def chat_with_bot(request: ChatRequest):
    """Send a question to the customs chatbot and return its answer.

    - Retrieves the articles relevant to the question.
    - Generates the answer with the LLM, passing the session ID along.
      NOTE(review): chat-history persistence (MongoDB) presumably happens
      inside the LLM/memory helpers — it is not visible in this function.
    - Returns the answer plus each retrieved article's number and similarity.

    When the request carries no session ID, a new ``session_<uuid4>`` ID is
    generated. Unexpected failures are reported as HTTP 500 with the
    exception text as detail. HTTP errors raised by inner calls propagate
    unchanged.
    """
    try:
        # An empty string also counts as "no session supplied".
        session_id = request.session_id or f"session_{uuid.uuid4()}"

        top_articles = find_relevant_articles(request.question)

        # Only the answer text is used; the second element (sources) is ignored.
        answer, _ = generate_synthesized_llm_response_with_sources(
            question=request.question,
            top_articles=top_articles,
            web_results=request.web_results or {},
            session_id=session_id,
        )

        # top_articles is consumed as (document, similarity) pairs. Similarity
        # scores may be NumPy scalars (e.g. float32), which the response
        # serializer cannot encode, so coerce to a built-in float first.
        articles = [
            {
                "article_num": doc.get("article_num"),
                "similarity": round(float(sim), 3),
            }
            for doc, sim in top_articles
        ]

        return ChatResponse(
            session_id=session_id,
            answer=answer,
            articles_found=articles,
        )
    except HTTPException:
        # Keep deliberate HTTP errors from inner calls instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# ======================================================
# LANCEMENT LOCAL
# ======================================================