File size: 2,794 Bytes
ef5f450
 
00e4869
 
 
 
 
ef5f450
00e4869
 
 
ef5f450
 
 
 
 
c8b552c
ef5f450
 
 
 
 
 
 
 
 
 
 
 
 
 
00e4869
 
 
ef5f450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00e4869
2a847ac
ef5f450
 
 
 
 
 
 
00e4869
 
 
 
 
 
 
 
ef5f450
 
 
00e4869
 
 
 
 
 
 
 
 
 
ef5f450
 
2a847ac
ef5f450
 
00e4869
 
 
 
 
 
 
 
 
 
 
 
 
 
2a847ac
00e4869
ef5f450
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
import time
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

from guardrails import validate_input, validate_output
from rag import RAGChain

# Configure root logging at INFO so this module's logger.info() messages are
# emitted (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared RAG chain instance. It stays None until `lifespan` constructs it;
# the request handlers read this module global.
rag_chain: Optional[RAGChain] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build and load the RAG chain, then serve.

    Code before ``yield`` runs at application startup and code after it runs
    at shutdown. The chain is published to the module-level ``rag_chain``
    before ``load()`` is called, exactly as handlers expect.
    """
    global rag_chain

    # --- startup ---
    logger.info("Loading RAG chain...")
    chain = rag_chain = RAGChain()
    chain.load()
    logger.info("RAG chain ready.")

    yield  # the app serves requests while suspended here

    # --- shutdown ---
    logger.info("Shutting down.")


# The ASGI application; `lifespan` loads the RAG chain at startup.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="Irminsul — Genshin Impact AI Assistant",
    description="RAG-powered assistant for Genshin Impact lore, builds, and mechanics.",
)

# NOTE(review): wildcard CORS lets any web origin call every endpoint,
# including POST /ingest. Consider restricting allow_origins before exposing
# this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    query: str  # the user's question; screened by validate_input before use
    # Number of passages to retrieve; forwarded as `top_k` to rag_chain.query.
    # The value comes from unauthenticated clients, so it is bounded:
    # non-positive values are meaningless, and very large ones presumably
    # inflate retrieval and prompt cost. Out-of-range values are rejected by
    # request validation (HTTP 422) instead of reaching the retriever.
    top_k: int = Field(default=3, ge=1, le=20)


class GenerateResponse(BaseModel):
    """Response body for POST /generate."""

    answer: str  # generated answer, or the guardrail's message when blocked
    sources: list[str]  # sources returned by the RAG chain; empty when blocked
    latency_ms: float  # RAG query time in ms, rounded to 0.1; 0.0 if input was blocked
    blocked: bool = False  # True when an input or output guardrail rejected the request
    retrieval_method: str = "rag"  # "rag" | "web_fallback" | "guardrail_blocked"


@app.get("/")
def ui():
    """Serve the front-end page.

    The path is relative, so ``index.html`` is resolved against the process's
    current working directory, not against this module's location.
    """
    page = "index.html"
    return FileResponse(page)


@app.get("/health")
def health():
    """Liveness probe that also reports whether the RAG chain is loaded and ready."""
    chain = rag_chain
    loaded = chain is not None and chain.ready
    return {"status": "ok", "model_loaded": loaded}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Answer ``req.query`` with the RAG chain, applying input and output guardrails.

    Returns:
        GenerateResponse. When a guardrail rejects the query or the generated
        answer, ``blocked`` is True, ``sources`` is empty and
        ``retrieval_method`` is ``"guardrail_blocked"``. Otherwise
        ``retrieval_method`` is whatever the chain reports.

    Raises:
        HTTPException: status 503 if the RAG chain is not loaded/ready yet.
    """
    # Read the global once so the readiness check and the query below use
    # the same object.
    chain = rag_chain
    if not chain or not chain.ready:
        raise HTTPException(status_code=503, detail="Model not loaded yet.")

    allowed, reason = validate_input(req.query)
    if not allowed:
        # Rejected before any retrieval/generation work, so no latency to report.
        return GenerateResponse(
            answer=reason,
            sources=[],
            latency_ms=0.0,
            blocked=True,
            retrieval_method="guardrail_blocked",
        )

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump, producing wrong or negative latencies.
    start = time.perf_counter()
    answer, sources, retrieval_method = chain.query(req.query, top_k=req.top_k)
    latency_ms = round((time.perf_counter() - start) * 1000, 1)

    # validate_output returns a verdict plus the text to send back to the client.
    is_clean, answer = validate_output(answer)
    if not is_clean:
        return GenerateResponse(
            answer=answer,
            sources=[],  # withhold sources for a blocked answer
            latency_ms=latency_ms,
            blocked=True,
            retrieval_method="guardrail_blocked",
        )

    return GenerateResponse(
        answer=answer,
        sources=sources,
        latency_ms=latency_ms,
        blocked=False,
        retrieval_method=retrieval_method,
    )


@app.post("/ingest")
def ingest(directory: str = "./docs"):
    """Ingest documents from a local directory into Pinecone.

    ``directory`` is a server-side path taken from the query string; relative
    paths resolve against the process's working directory.

    NOTE(review): this endpoint has no authentication and accepts any path,
    so any caller can choose which server directory gets indexed. Consider
    confining ``directory`` to a fixed document root or protecting the route.
    """
    # Imported lazily so the ingestion module is only loaded when this
    # endpoint is actually called.
    from ingest import ingest_documents as _ingest_documents

    ingested_count = _ingest_documents(directory)
    return {"ingested": ingested_count, "directory": directory}