File size: 2,477 Bytes
4e7e4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pathlib import Path
import json
import faiss
import logging
from contextlib import asynccontextmanager

from src.embedding.embedding_client import EmbeddingClient
import os
from src.embedding.embeddings import EmbeddingModel
from src.vector_store.FAISS_store import VectorStore
from src.services.rag_service import RAGService

logger = logging.getLogger("rag-api")
logging.basicConfig(level=logging.INFO)


@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        logger.info("Starting RAG API...")

        faiss_path = Path("faiss.index")
        metadata_path = Path("metadata.json")

        if not faiss_path.exists() or not metadata_path.exists():
            logger.warning("FAISS or metadata not found. API will start in degraded mode.")
            app.state.rag_service = None
            yield
            return

        embedder = EmbeddingModel()
        index = faiss.read_index(str(faiss_path))
        store = VectorStore(dim=index.d)
        store.index = index

        with open(metadata_path, "r", encoding="utf-8") as f:
            store.metadata = json.load(f)

        app.state.store = store
        app.state.rag_service = RAGService(store, embedder)

        logger.info("RAG API ready.")
        yield

    except Exception as e:
        logger.exception("Startup failed")
        raise e


app = FastAPI(
    title="Lecture-Saver 3000 API",
    version="1.0.0",
    lifespan=lifespan,
)


class ChatRequest(BaseModel):
    question: str


@app.get("/health")
def health():
    return {
        "status": "ok",
        "rag_loaded": app.state.rag_service is not None
    }


@app.post("/ask")
async def ask(request: ChatRequest):
    if app.state.rag_service is None:
        raise HTTPException(
            status_code=503,
            detail="RAG service not initialized"
        )

    try:
        response = await app.state.rag_service.answer_question(request.question)
    except Exception as e:
        logger.exception("Unhandled RAG error")
        raise HTTPException(status_code=500, detail="Internal RAG failure")

    if response.get("status") == "error":
        status_code = 503 if response.get("error_type") == "LLMGenerationError" else 500
        raise HTTPException(status_code=status_code, detail=response.get("message"))

    return response