|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
import uuid |
|
|
import threading |
|
|
import traceback |
|
|
import logging |
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Request |
|
|
from fastapi.responses import HTMLResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.templating import Jinja2Templates |
|
|
from pydantic import BaseModel |
|
|
|
|
|
from multi_doc_chat.src.document_ingestion.data_ingestion import ingest_upload_files |
|
|
from multi_doc_chat.rag_service import create_rag_service |
|
|
|
|
|
|
|
|
# Configure logging once at import time: INFO threshold and a
# "timestamp [LEVEL] message" layout. Then create this module's logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# This import sits below the logging setup rather than with the other imports.
from scripts.download_models import main as download_models_main

# Best-effort call at import time: any Exception is logged as a warning
# together with its traceback, and module loading continues.
try:
    download_models_main()
except Exception as exc:
    failure_trace = traceback.format_exc()
    logger.warning("download_models_main() failed: %s\n%s", exc, failure_trace)
|
|
|
|
|
|
|
|
# Filesystem layout, anchored at the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parent
FAISS_DIR = ROOT_DIR.joinpath("faiss_index")
MODELS_DIR = ROOT_DIR.joinpath("models")
|
|
|
|
|
|
|
|
# Single module-level RAG service instance used by the request handlers below.
rag_service = create_rag_service(faiss_dir=str(FAISS_DIR))

# Startup diagnostics: index location, document count, and whether the
# loader exposes a truthy embedder.
logger.info("RAGService initialized at %s", FAISS_DIR)
logger.info("Number of documents in FAISS index: %d", len(rag_service.documents))
if not rag_service.loader.embedder:
    logger.warning("Embedding model not loaded")
else:
    logger.info("Embedding model loaded successfully")
|
|
|
|
|
|
|
|
app = FastAPI(title="RAG Solution", version="0.1.0")

# CORS policy.
# A wildcard origin must not be combined with allow_credentials=True:
# FastAPI's CORS documentation requires origins to be listed explicitly when
# credentials are enabled. Allowed origins therefore come from the optional
# CORS_ALLOW_ORIGINS environment variable (comma-separated). When it is unset,
# empty, or contains "*", every origin is allowed but credentials are disabled.
# Credentials are enabled only for an explicit origin list.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_wildcard = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if _cors_wildcard else _cors_origins,
    allow_credentials=not _cors_wildcard,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
templates_dir = ROOT_DIR / "templates"
static_dir = ROOT_DIR / "static"

# Static assets are mounted only when the directory is present on disk;
# without it the "/static" route is simply not registered.
if static_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(static_dir)),
        name="static",
    )

# Template engine used by the home page handler.
templates = Jinja2Templates(directory=str(templates_dir))
|
|
|
|
|
# In-memory chat state, kept for the lifetime of the process.
# SESSIONS maps a session id to its message history; the handlers below store
# entries shaped like {"role": ..., "content": ...}.
# NOTE(review): nothing in this file ever removes a session, so this mapping
# only grows -- confirm whether expiry is handled elsewhere.
SESSIONS: Dict[str, List[dict]] = dict()

# Lock the upload and chat handlers take around their calls into rag_service.
STATE_LOCK = threading.Lock()
|
|
|
|
|
|
|
|
class UploadResponse(BaseModel):
    """Response body returned by the /upload endpoint."""

    # Identifier of the chat session created for this upload.
    session_id: str
    # Indexing outcome flag set by the upload handler.
    indexed: bool
    # Optional human-readable status text.
    message: str | None = None
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the /chat endpoint."""

    # Session identifier previously returned by /upload.
    session_id: str
    # The user's message text.
    message: str
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response body returned by the /chat endpoint."""

    # Answer produced by the RAG service for the submitted message.
    answer: str
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report service status, document count, and embedder availability.

    Returns:
        A dict with "status" (always "ok"), "faiss_docs" (length of
        rag_service.documents) and "model_loaded" (True when the loader's
        embedder is not None).
    """
    document_count = len(rag_service.documents)
    embedder_present = rag_service.loader.embedder is not None
    return {
        "status": "ok",
        "faiss_docs": document_count,
        "model_loaded": embedder_present,
    }
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the index.html template for the landing page.

    Args:
        request: The incoming request, passed into the template context.

    Returns:
        The template response produced for "index.html".
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
|
|
@app.post("/upload", response_model=UploadResponse)
async def upload(files: List[UploadFile] = File(...)) -> UploadResponse:
    """Ingest uploaded files through the shared RAG service and open a session.

    Args:
        files: Files from the multipart request body.

    Returns:
        UploadResponse carrying the newly created session id.

    Raises:
        HTTPException: 400 when no files were sent; 500 when ingestion fails
            with an unexpected error. An HTTPException raised by ingestion
            itself is passed through with its original status code.
    """
    # Function-scope import: asyncio is not imported at module level.
    import asyncio

    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    session_id = str(uuid.uuid4())
    logger.info("Files received: %s", [f.filename for f in files])

    try:
        # STATE_LOCK is a threading.Lock. A blocking acquire inside a coroutine
        # would freeze the event loop while another task holds the lock across
        # its own await, and that task could then never resume to release it.
        # Polling with a non-blocking acquire keeps the loop running, and a
        # cancellation can only land on the sleep, when the lock is not held.
        while not STATE_LOCK.acquire(blocking=False):
            await asyncio.sleep(0.05)
        try:
            await ingest_upload_files(files, rag_service)
        finally:
            STATE_LOCK.release()
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewriting them as 500.
        raise
    except Exception as e:
        logger.exception("Upload failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Upload failed: {e}") from e

    # Register the session only after ingestion succeeded, so a failed upload
    # does not leave behind a session id that /chat would accept.
    SESSIONS[session_id] = []

    logger.info("Files ingested successfully for session %s", session_id)
    return UploadResponse(
        session_id=session_id,
        indexed=True,
        message=f"Ingested under session {session_id}",
    )
|
|
|
|
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest) -> ChatResponse:
    """Answer a message for an existing session and record the exchange.

    Args:
        req: Chat request holding the session id and the user's message.

    Returns:
        ChatResponse with the answer returned by rag_service.query.

    Raises:
        HTTPException: 400 when the session id is unknown or the message is
            blank; 500 when the query fails with an unexpected error.
    """
    # Function-scope import: asyncio is not imported at module level.
    import asyncio

    session_id = req.session_id or ""
    if not session_id or session_id not in SESSIONS:
        raise HTTPException(
            status_code=400,
            detail="Invalid or expired session_id. Re-upload documents.",
        )

    message = (req.message or "").strip()
    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    def _answer_and_record():
        # Runs in a worker thread. The blocking lock acquisition and the
        # synchronous query therefore never stall the event loop, and the lock
        # is always released by the thread that took it.
        with STATE_LOCK:
            answer = rag_service.query(message)
            history = SESSIONS[session_id]
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": answer})
            return answer

    logger.info("Chat request received: session=%s, message=%s", session_id, message)
    try:
        answer = await asyncio.to_thread(_answer_and_record)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewriting them as 500.
        raise
    except Exception as e:
        logger.exception("Chat failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Chat failed: {e}") from e

    logger.info("Chat response sent for session %s", session_id)
    return ChatResponse(answer=answer)
|
|
|
|
|
|
|
|
@app.get("/sessions")
def list_sessions():
    """List the ids of all sessions currently held in memory.

    Returns:
        A dict with a single "sessions" key holding the list of session ids.
    """
    # NOTE(review): /chat accepts any id present in SESSIONS, so this endpoint
    # hands out every id needed to post into other sessions -- confirm that
    # exposing it is intended.
    session_ids = [sid for sid in SESSIONS]
    return {"sessions": session_ids}
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve "main:app" with uvicorn on all interfaces.
    import uvicorn

    # The PORT environment variable overrides the default of 8000.
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run("main:app", host="0.0.0.0", port=listen_port)
|
|
|