# ======================================
# main.py – FastAPI backend for the Tunisian customs chatbot 🇹🇳
# ======================================
import uuid
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Internal (project) modules — kept in their original import order.
from core.retrieval import find_relevant_articles
from core.llm import generate_synthesized_llm_response_with_sources
from core.memory import (
    start_new_session,
    add_message_to_session,
    rename_session,
    update_session_title,
    get_messages_for_session,
    load_all_sessions,
)

# ======================================================
# APPLICATION CONFIGURATION
# ======================================================
app = FastAPI(title="Chatbot Douane API 🇹🇳")

# CORS: accept cross-origin requests so a separately hosted frontend can
# call this API.
# NOTE(review): "*" allows every origin, method and header; consider an
# explicit origin list for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ----------------------------
# Request / response schemas (Pydantic)
# ----------------------------

# A single chat message; `type` falls back to "human" when the client omits it.
class MessageData(BaseModel):
    type: Optional[str] = "human"
    content: str


# Body of POST /sessions/add_message.
class AddMessageRequest(BaseModel):
    session_id: str
    message: MessageData


# Body of POST /sessions/rename.
class RenameSessionRequest(BaseModel):
    session_id: str
    new_title: str


# Response of POST /sessions/new.
class NewSessionResponse(BaseModel):
    session_id: str


# Body of POST /chat.
class ChatRequest(BaseModel):
    question: str
    # When missing (or empty), the /chat endpoint generates a fresh ID.
    session_id: Optional[str] = None
    # Optional extra context passed through to the LLM call.
    web_results: Optional[Dict] = None


# Response of POST /chat: the answer plus the articles used to produce it.
class ChatResponse(BaseModel):
    session_id: str
    answer: str
    articles_found: List[Dict]


# ----------------------------
# Session endpoints
# ----------------------------
@app.get("/sessions")
def get_all_sessions():
    """List the stored sessions (titles and creation dates)."""
    return load_all_sessions()


@app.post("/sessions/new", response_model=NewSessionResponse)
def create_session():
    """Open a new chat session and return its identifier."""
    # start_new_session is handed a fresh, empty state container on each call.
    new_id = start_new_session({"sessions": {}})
    return {"session_id": new_id}
# ----------------------------
# Shared helpers
# ----------------------------
def _to_dict(model: BaseModel) -> dict:
    """Convert a Pydantic model to a plain dict on both Pydantic v1 and v2.

    Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``; use the
    latter when available so no deprecation warning is emitted, and fall back
    to ``.dict()`` on v1.
    """
    model_dump = getattr(model, "model_dump", None)
    if callable(model_dump):
        return model_dump()
    return model.dict()


def _http_error(status_code: int, exc: Exception) -> HTTPException:
    """Log *exc* (with traceback) and wrap it in an HTTPException.

    FastAPI does not log HTTPExceptions, so without this the original
    traceback would be lost as soon as the error is converted.

    Args:
        status_code: HTTP status to return to the client.
        exc: The exception that was caught; its text becomes the detail.

    Returns:
        An HTTPException whose ``detail`` is ``str(exc)``, to be raised
        by the caller (``raise _http_error(...) from exc``).
    """
    import logging  # local import: the file's import block lives elsewhere

    level = logging.ERROR if status_code >= 500 else logging.WARNING
    logging.getLogger(__name__).log(
        level, "Request failed with HTTP %s", status_code, exc_info=exc
    )
    return HTTPException(status_code=status_code, detail=str(exc))


@app.post("/sessions/add_message")
def add_message(req: AddMessageRequest):
    """Add a message to a session."""
    try:
        add_message_to_session(req.session_id, _to_dict(req.message))
    except Exception as e:
        raise _http_error(400, e) from e
    return {"status": "success"}


@app.post("/sessions/rename")
def rename(req: RenameSessionRequest):
    """Rename a session manually."""
    try:
        rename_session(req.session_id, req.new_title)
    except Exception as e:
        raise _http_error(400, e) from e
    return {"status": "success"}


@app.post("/sessions/update_title/{session_id}")
def update_title(session_id: str):
    """Update session title automatically based on first message."""
    try:
        update_session_title(session_id)
    except Exception as e:
        raise _http_error(400, e) from e
    return {"status": "success"}


@app.get("/sessions/{session_id}/messages")
def get_messages(session_id: str):
    """Get all messages for a given session."""
    try:
        return get_messages_for_session(session_id)
    except Exception as e:
        raise _http_error(400, e) from e


# ----------------------------
# Chatbot endpoint
# ----------------------------
@app.post("/chat", response_model=ChatResponse)
def chat_with_bot(request: ChatRequest):
    """
    Send a question to the customs chatbot.

    - Retrieves the relevant articles
    - Generates the answer with the LLM
    """
    # Reject blank questions before doing any retrieval/LLM work. This sits
    # outside the try below so the broad handler does not turn it into a 500.
    if not request.question.strip():
        raise HTTPException(
            status_code=400, detail="La question ne peut pas être vide."
        )

    try:
        # No session supplied (or empty string): mint a new identifier.
        session_id = request.session_id or f"session_{uuid.uuid4()}"

        # Retrieve the relevant articles as (doc, similarity) pairs.
        top_articles = find_relevant_articles(request.question)

        # Generate the answer via the LLM. The second return value is unused.
        # NOTE(review): history persistence (MongoDB per the original notes)
        # presumably happens inside this call, keyed by session_id — it is
        # not visible here; confirm in core.llm.
        answer, _ = generate_synthesized_llm_response_with_sources(
            question=request.question,
            top_articles=top_articles,
            web_results=request.web_results or {},
            session_id=session_id,
        )

        # float() first: if retrieval returns numpy/torch scalars, round()
        # would keep that type, which Pydantic cannot serialize in List[Dict].
        articles = [
            {
                "article_num": doc.get("article_num"),
                "similarity": round(float(sim), 3),
            }
            for doc, sim in top_articles
        ]

        return ChatResponse(
            session_id=session_id,
            answer=answer,
            articles_found=articles,
        )
    except Exception as e:
        raise _http_error(500, e) from e


# ======================================================
# LOCAL LAUNCH
# ======================================================