Spaces:
Sleeping
Sleeping
| # ====================================== | |
| # main.py – FastAPI pour chatbot douanier 🇹🇳 | |
| # ====================================== | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict | |
| import uuid | |
| import uvicorn | |
| # Import des modules internes | |
| from core.retrieval import find_relevant_articles | |
| from core.llm import generate_synthesized_llm_response_with_sources | |
| from core.memory import ( | |
| start_new_session, | |
| add_message_to_session, | |
| rename_session, | |
| update_session_title, | |
| get_messages_for_session, | |
| load_all_sessions | |
| ) | |
| # ====================================================== | |
| # CONFIGURATION DE L'APPLICATION | |
| # ====================================================== | |
# FastAPI application instance for the customs chatbot API.
app = FastAPI(title="Chatbot Douane API 🇹🇳")

# CORS settings so that a separately hosted frontend can call this API.
# NOTE(review): every origin, method and header is accepted — confirm this is
# intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| # ---------------------------- | |
| # Pydantic models | |
| # ---------------------------- | |
# A single chat message as carried inside an AddMessageRequest.
class MessageData(BaseModel):
    # Message kind; falls back to "human" when the client omits it.
    type: Optional[str] = "human"
    # Message text (required).
    content: str
# Request body consumed by add_message(): which session to append to, and what.
class AddMessageRequest(BaseModel):
    # Identifier of the target session.
    session_id: str
    # The message to append.
    message: MessageData
# Request body consumed by rename(): a session and the title to give it.
class RenameSessionRequest(BaseModel):
    # Identifier of the session being renamed.
    session_id: str
    # Replacement title.
    new_title: str
# Response shape for session creation; has the same single key as the dict
# returned by create_session().
class NewSessionResponse(BaseModel):
    # Identifier of the newly created session.
    session_id: str
# Request body consumed by chat_with_bot().
class ChatRequest(BaseModel):
    # The user's question (required).
    question: str
    # Existing session to continue; when missing or empty, chat_with_bot()
    # generates a fresh "session_<uuid4>" identifier.
    session_id: Optional[str] = None
    # Optional web search results; chat_with_bot() substitutes {} when absent.
    web_results: Optional[Dict] = None
# Response returned by chat_with_bot().
class ChatResponse(BaseModel):
    # Session the exchange belongs to (supplied by the client or generated).
    session_id: str
    # Answer text produced by the LLM helper.
    answer: str
    # One dict per retrieved article; chat_with_bot() fills each with the keys
    # "article_num" and "similarity".
    articles_found: List[Dict]
| # ---------------------------- | |
| # Session endpoints | |
| # ---------------------------- | |
def get_all_sessions():
    """List every stored chat session.

    Returns:
        Whatever ``load_all_sessions()`` returns, unchanged. Presumably each
        entry carries a title and a creation date — verify against
        ``core.memory``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    return load_all_sessions()
def create_session():
    """Open a new chat session and report its identifier.

    Returns:
        dict: ``{"session_id": <id>}`` where ``<id>`` is the value returned by
        ``start_new_session``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    # A brand-new, empty state container is handed over on every call.
    fresh_state = {"sessions": {}}
    new_id = start_new_session(fresh_state)
    return dict(session_id=new_id)
def add_message(req: AddMessageRequest):
    """Append one message to an existing session.

    Args:
        req: Carries the target ``session_id`` and the message, which is
            converted to a plain dict before being stored.

    Returns:
        dict: ``{"status": "success"}`` when the helper completes.

    Raises:
        HTTPException: Status 400 carrying the text of any exception raised
            while converting or storing the message.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    try:
        add_message_to_session(req.session_id, req.message.dict())
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        return {"status": "success"}
def rename(req: RenameSessionRequest):
    """Give a session a title chosen by the caller.

    Args:
        req: Carries the ``session_id`` to rename and its ``new_title``.

    Returns:
        dict: ``{"status": "success"}`` when the helper completes.

    Raises:
        HTTPException: Status 400 carrying the text of any exception raised by
            ``rename_session``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    try:
        rename_session(req.session_id, req.new_title)
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        return {"status": "success"}
def update_title(session_id: str):
    """Ask the memory layer to regenerate a session's title.

    The title logic lives in ``update_session_title``; per the original
    docstring it is derived from the first message — verify in ``core.memory``.

    Args:
        session_id: Identifier of the session whose title is refreshed.

    Returns:
        dict: ``{"status": "success"}`` when the helper completes.

    Raises:
        HTTPException: Status 400 carrying the text of any exception raised by
            ``update_session_title``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    try:
        update_session_title(session_id)
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        return {"status": "success"}
def get_messages(session_id: str):
    """Fetch the message history of one session.

    Args:
        session_id: Identifier of the session to read.

    Returns:
        Whatever ``get_messages_for_session`` returns, unchanged.

    Raises:
        HTTPException: Status 400 carrying the text of any exception raised by
            ``get_messages_for_session``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    try:
        return get_messages_for_session(session_id)
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
| # ---------------------------- | |
| # Chatbot endpoint | |
| # ---------------------------- | |
def chat_with_bot(request: ChatRequest):
    """Answer a question sent to the customs chatbot.

    Steps:
        1. Reuse ``request.session_id`` or, when it is missing or empty,
           generate a ``session_<uuid4>`` identifier.
        2. Retrieve the relevant articles with ``find_relevant_articles``.
        3. Produce the answer with
           ``generate_synthesized_llm_response_with_sources``.
        4. Return the answer together with a summary of the retrieved articles.

    NOTE(review): the original docstring stated that the history is saved in
    MongoDB, but no persistence call is visible here — presumably the LLM
    helper does it using ``session_id``; confirm.

    Args:
        request: Question, optional session id and optional web results.

    Returns:
        ChatResponse: Session id, answer text and one
        ``{"article_num", "similarity"}`` dict per retrieved article.

    Raises:
        HTTPException: Re-raised untouched when a helper raises one; any other
            exception is reported as status 500 carrying the exception text.
    """
    # NOTE(review): no route decorator is visible in this view — confirm that
    # this handler is actually registered on ``app``.
    try:
        # A generated id is not registered through start_new_session here.
        session_id = request.session_id or f"session_{uuid.uuid4()}"

        top_articles = find_relevant_articles(request.question)

        # Only the first element of the helper's result is used.
        answer, _ = generate_synthesized_llm_response_with_sources(
            question=request.question,
            top_articles=top_articles,
            web_results=request.web_results or {},
            session_id=session_id,
        )

        # float() turns scores coming from numeric libraries (e.g. numpy
        # scalars) into builtin floats so the response stays JSON-serializable.
        articles = [
            {
                "article_num": doc.get("article_num"),
                "similarity": round(float(sim), 3),
            }
            for doc, sim in top_articles
        ]

        return ChatResponse(
            session_id=session_id,
            answer=answer,
            articles_found=articles,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (and their status codes) intact instead
        # of rewrapping them as a generic 500.
        raise
    except Exception as exc:
        # NOTE(review): the raw exception text is sent to the client; consider
        # logging it and returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
| # ====================================================== | |
| # LANCEMENT LOCAL | |
| # ====================================================== | |