"""FastAPI entry point for the RAG question-answering web app.

Serves the ``index.html`` page at ``/`` and answers questions posted to
``/ask`` by retrieving relevant chunks and passing them, as context, to
``generate_answer``.
"""

from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.llm import generate_answer
from app.rag import load_data, retrieve_chunks

app = FastAPI()

# Static assets and Jinja2 templates; both directory names are relative
# paths, so they resolve against the process's working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# RAG corpus, populated once by ``startup_event`` rather than at import time.
documents = None
embeddings = None


@app.on_event("startup")
def startup_event():
    """Load the documents and their embeddings once when the app starts."""
    global documents, embeddings
    documents, embeddings = load_data()


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` with an empty answer."""
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "answer": ""},
    )


@app.post("/ask")
def ask(query: str = Form(...)):
    """Answer the form field *query* using retrieved context.

    Declared as a plain ``def`` on purpose: ``retrieve_chunks`` and
    ``generate_answer`` are called synchronously, and FastAPI runs sync
    endpoints in a worker threadpool. As ``async def`` these calls would
    block the event loop, stalling every other request until they return.

    Returns:
        ``{"answer": <generated answer>}``, serialized to JSON by FastAPI.
    """
    retrieved = retrieve_chunks(query, documents, embeddings)
    # Separate chunks with a blank line so they read as distinct passages.
    context = "\n\n".join(retrieved)
    answer = generate_answer(context, query)
    return {"answer": answer}