File size: 2,794 Bytes
import logging
import time
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from guardrails import validate_input, validate_output
from rag import RAGChain
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Process-wide RAG chain shared by the route handlers; stays None until
# lifespan() assigns it.
rag_chain: Optional[RAGChain] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create and load the shared RAG chain, then hand control to the app.

    Given to the ``FastAPI`` constructor as its ``lifespan`` argument. The
    statements before ``yield`` populate the module-level ``rag_chain``;
    the statement after ``yield`` only logs that the app is stopping.

    Args:
        app: The application instance (not used by this function).
    """
    global rag_chain

    logger.info("Loading RAG chain...")
    # The global is bound before load() is called, so it is non-None even
    # while (or if) loading is still in progress or fails.
    rag_chain = RAGChain()
    rag_chain.load()
    logger.info("RAG chain ready.")

    yield

    logger.info("Shutting down.")
# Application object; the route decorators further down register their
# handlers on it, and `lifespan` supplies its startup/shutdown hook.
app = FastAPI(
    title="Irminsul — Genshin Impact AI Assistant",
    description="RAG-powered assistant for Genshin Impact lore, builds, and mechanics.",
    version="2.0.0",
    lifespan=lifespan,
)
# CORS policy with wildcards for origins, methods, and headers.
# NOTE(review): this is fully permissive — confirm it is intended outside
# of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerateRequest(BaseModel):
    # Request body accepted by POST /generate.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # User text; checked by validate_input() and then sent to the RAG chain.
    query: str

    # Forwarded as `top_k` to RAGChain.query(); presumably the number of
    # retrieved passages — confirm against the rag module.
    top_k: int = 3
class GenerateResponse(BaseModel):
    # Response body returned by POST /generate.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # Answer text, or the guardrail's message when the request was blocked.
    answer: str

    # Source identifiers reported by the RAG chain; empty when blocked.
    sources: list[str]

    # Time spent in the RAG chain query, in milliseconds.
    latency_ms: float

    # True when an input or output guardrail rejected the exchange.
    blocked: bool = False

    retrieval_method: str = "rag"  # "rag" | "web_fallback" | "guardrail_blocked"
@app.get("/")
def ui():
    # Serve the front-end page at the site root.
    # NOTE(review): the path is relative, so which file is served depends on
    # the working directory the server process was started from.
    page = FileResponse("index.html")
    return page
@app.get("/health")
def health():
    # Liveness probe: always reports "ok" and additionally says whether the
    # shared RAG chain exists and flags itself as ready.
    if rag_chain is None:
        model_loaded = False
    else:
        model_loaded = rag_chain.ready

    return {"status": "ok", "model_loaded": model_loaded}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    # Answer a user query through the RAG chain, guarded on both sides.
    #
    # Flow:
    #   1. Refuse with HTTP 503 while the shared chain is missing or not ready.
    #   2. Run the input guardrail; a rejection is returned as a normal
    #      (200) response with blocked=True and the guardrail's reason.
    #   3. Query the chain, timing only that call.
    #   4. Run the output guardrail; a rejection hides the sources and
    #      returns the guardrail's replacement text with blocked=True.
    #
    # Blocked responses carry retrieval_method="guardrail_blocked", the value
    # documented on GenerateResponse, instead of silently defaulting to "rag".
    #
    # Raises:
    #   HTTPException(503): the RAG chain has not finished loading.

    # Explicit None check, matching the /health endpoint.
    if rag_chain is None or not rag_chain.ready:
        raise HTTPException(status_code=503, detail="Model not loaded yet.")

    allowed, reason = validate_input(req.query)
    if not allowed:
        return GenerateResponse(
            answer=reason,
            sources=[],
            latency_ms=0.0,  # the chain was never invoked
            blocked=True,
            retrieval_method="guardrail_blocked",
        )

    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    answer, sources, retrieval_method = rag_chain.query(req.query, top_k=req.top_k)
    latency_ms = round((time.perf_counter() - start) * 1000, 1)

    # validate_output() returns the text to show, which replaces the raw
    # answer whether or not it was judged clean.
    is_clean, answer = validate_output(answer)
    if not is_clean:
        return GenerateResponse(
            answer=answer,
            sources=[],
            latency_ms=latency_ms,
            blocked=True,
            retrieval_method="guardrail_blocked",
        )

    return GenerateResponse(
        answer=answer,
        sources=sources,
        latency_ms=latency_ms,
        blocked=False,
        retrieval_method=retrieval_method,
    )
@app.post("/ingest")
def ingest(directory: str = "./docs"):
    """Ingest documents from a local directory into Pinecone."""
    # Parameters:
    #   directory: path on the server's filesystem, supplied by the caller.
    # Returns:
    #   {"ingested": <value returned by ingest_documents>, "directory": <path>}
    # Raises:
    #   HTTPException(400): `directory` does not name an existing directory.
    #
    # NOTE(review): this endpoint is unauthenticated and accepts an arbitrary
    # server-side path. Consider restricting it to an allow-listed base
    # directory and/or requiring authentication before exposing it publicly.
    from pathlib import Path

    # Imported lazily so the ingestion module is only loaded when this
    # endpoint is actually called.
    from ingest import ingest_documents

    # Reject bad paths up front with a client error instead of letting the
    # ingestion code fail on them.
    if not Path(directory).is_dir():
        raise HTTPException(
            status_code=400,
            detail=f"Not a directory: {directory}",
        )

    count = ingest_documents(directory)
    return {"ingested": count, "directory": directory}