LeonardoMdSA's picture
added unit tests and integration tests
0c6c6fd
# main.py
from __future__ import annotations

import asyncio
import logging
import os
import threading
import traceback
import uuid
from pathlib import Path
from typing import Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from multi_doc_chat.src.document_ingestion.data_ingestion import ingest_upload_files
from multi_doc_chat.rag_service import create_rag_service
# --- Configure logging ---
# Root logger: INFO and above, timestamped "<time> [<LEVEL>] <message>" lines.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
# --- Ensure models downloaded ---
# Best-effort bootstrap: a failure inside the download routine is logged as a
# warning (with the traceback) and startup continues.
from scripts.download_models import main as download_models_main

try:
    download_models_main()
except Exception as exc:
    logger.warning(
        "download_models_main() failed: %s\n%s",
        exc,
        traceback.format_exc(),
    )
# --- ROOT paths for FAISS and models ---
# Everything is resolved relative to the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parent
# Passed to create_rag_service() as the index location.
FAISS_DIR = ROOT_DIR.joinpath("faiss_index")
# Not referenced elsewhere in the visible part of this file.
MODELS_DIR = ROOT_DIR.joinpath("models")
# --- Create RAGService ---
# One service instance, created at import time and shared by all endpoints.
rag_service = create_rag_service(faiss_dir=str(FAISS_DIR))

# --- Startup logging for monitoring ---
logger.info("RAGService initialized at %s", FAISS_DIR)
logger.info("Number of documents in FAISS index: %d", len(rag_service.documents))
# Truthiness check on the loader's embedder decides which message is logged.
if not rag_service.loader.embedder:
    logger.warning("Embedding model not loaded")
else:
    logger.info("Embedding model loaded successfully")
# --- FastAPI app ---
app = FastAPI(title="RAG Solution", version="0.1.0")

# CORS: any origin may call the API. Credentials are deliberately disabled:
# FastAPI's CORS documentation states that wildcard origins/methods/headers
# must not be combined with allow_credentials=True, and the endpoints in this
# file identify a session by the session_id sent in the request body rather
# than by cookies. To support credentialed cross-origin requests, list the
# trusted origins explicitly instead of "*" before re-enabling this flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Front-end assets live next to this file.
templates_dir = ROOT_DIR / "templates"
static_dir = ROOT_DIR / "static"

# The /static route is only mounted when the directory is present on disk.
if static_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(static_dir)),
        name="static",
    )

# Template engine used by the home page handler.
templates = Jinja2Templates(directory=str(templates_dir))
# In-memory chat history: session_id -> list of {"role", "content"} dicts.
# Entries are added by /upload and appended to by /chat; nothing in this file
# removes them, so the mapping lives (and grows) for the process lifetime.
SESSIONS: Dict[str, List[dict]] = dict()

# Non-reentrant lock taken around ingestion and querying of the shared
# rag_service.
STATE_LOCK = threading.Lock()
# --- Pydantic models ---
# Response body of POST /upload.
class UploadResponse(BaseModel):
    # Identifier the client must send back with every /chat request.
    session_id: str
    # Whether the uploaded files were ingested; the handler sets this to True
    # on the success path.
    indexed: bool
    # Optional human-readable status text.
    message: str | None = None
# Request body of POST /chat.
class ChatRequest(BaseModel):
    # Session identifier previously returned by /upload.
    session_id: str
    # The user's question; the handler strips it and rejects empty text.
    message: str
# Response body of POST /chat.
class ChatResponse(BaseModel):
    # Value returned by rag_service.query() for the user's message.
    answer: str
# --- Health endpoint ---
# Returns a fixed "ok" status together with the number of entries in
# rag_service.documents and whether the loader has an embedder object.
@app.get("/health")
def health():
    indexed_count = len(rag_service.documents)
    embedder_ready = rag_service.loader.embedder is not None
    return {
        "status": "ok",
        "faiss_docs": indexed_count,
        "model_loaded": embedder_ready,
    }
# --- Home endpoint ---
# Renders templates/index.html; the request object is the only context value.
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# --- Upload endpoint ---
# Waits for STATE_LOCK without blocking the event loop. STATE_LOCK is a
# threading.Lock; a plain blocking acquire inside a coroutine would freeze the
# event-loop thread while another coroutine holds the lock across an await,
# and that holder could then never resume to release it. Polling with a
# non-blocking acquire keeps the loop running and leaves nothing held if the
# waiting request is cancelled.
async def _acquire_state_lock(poll_interval: float = 0.05) -> None:
    while not STATE_LOCK.acquire(blocking=False):
        await asyncio.sleep(poll_interval)


# Ingests the uploaded files into the shared rag_service and opens a new chat
# session.
#   files: one or more uploaded files (multipart form field "files").
#   returns: UploadResponse carrying the new session_id.
#   raises: HTTPException 400 when the file list is empty,
#           HTTPException 500 when ingestion fails.
@app.post("/upload", response_model=UploadResponse)
async def upload(files: List[UploadFile] = File(...)) -> UploadResponse:
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    session_id = str(uuid.uuid4())
    try:
        logger.info("Files received: %s", [f.filename for f in files])
        # ingest the files using the shared rag_service
        await _acquire_state_lock()
        try:
            await ingest_upload_files(files, rag_service)
        finally:
            STATE_LOCK.release()

        # Register the session only after ingestion succeeded, so a failed
        # upload does not leave behind a session id that /chat would accept.
        SESSIONS[session_id] = []
        logger.info("Files ingested successfully for session %s", session_id)
        return UploadResponse(
            session_id=session_id,
            indexed=True,
            message=f"Ingested under session {session_id}",
        )
    except Exception as e:
        tb = traceback.format_exc()
        logger.error("Upload failed: %s\n%s", e, tb)
        raise HTTPException(status_code=500, detail=f"Upload failed: {e}") from e
# --- Chat endpoint ---
# Answers one user message for an existing session and records the exchange in
# that session's history.
#   req: ChatRequest with the session_id from /upload and the user's message.
#   returns: ChatResponse wrapping the value returned by rag_service.query().
#   raises: HTTPException 400 for an unknown session_id or a blank message,
#           HTTPException 500 when the query fails.
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest) -> ChatResponse:
    session_id = req.session_id or ""
    if not session_id or session_id not in SESSIONS:
        raise HTTPException(status_code=400, detail="Invalid or expired session_id. Re-upload documents.")

    message = (req.message or "").strip()
    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    def _locked_query():
        # Runs in a worker thread: both the wait for STATE_LOCK and the
        # synchronous query happen off the event loop.
        with STATE_LOCK:
            return rag_service.query(message)

    try:
        logger.info("Chat request received: session=%s, message=%s", session_id, message)
        # Taking the threading lock and running the blocking query directly in
        # this coroutine would stall every other request, and would freeze the
        # event loop for good if the lock is held by a coroutine suspended at
        # an await (the upload handler awaits while holding it).
        # NOTE(review): assumes rag_service.query may be called from a thread
        # other than the one that created the service — confirm.
        answer = await asyncio.to_thread(_locked_query)

        history = SESSIONS[session_id]
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": answer})
        logger.info("Chat response sent for session %s", session_id)
        return ChatResponse(answer=answer)
    except Exception as e:
        tb = traceback.format_exc()
        logger.error("Chat failed: %s\n%s", e, tb)
        raise HTTPException(status_code=500, detail=f"Chat failed: {e}") from e
# --- List sessions ---
# Returns every session id currently held in SESSIONS.
# NOTE(review): ids are the only thing /chat checks, and this route returns
# all of them to any caller — confirm that exposure is intended.
@app.get("/sessions")
def list_sessions():
    session_ids = list(SESSIONS)
    return {"sessions": session_ids}
# --- Run server ---
# Script entry point: serve on all interfaces, port taken from $PORT
# (default 8000).
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    # Pass the app object rather than the "main:app" import string. With the
    # string, uvicorn imports this file a second time under the module name
    # "main", which repeats every import-time step above (model download,
    # RAG service creation) and only works when the file is named main.py.
    # An import string is only required for uvicorn's reload/workers options,
    # which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=port)