# Author: ishmeet-yo
# Commit message: Update app/main.py
# Commit: fa81f98 (verified)
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from app.llm import generate_answer
from app.rag import load_data, retrieve_chunks
app = FastAPI()

# Static assets are served under the /static URL prefix, and HTML pages are
# rendered from the "templates" directory. Both directory names are relative
# paths.
_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")
templates = Jinja2Templates(directory="templates")

# RAG corpus state shared with the request handlers. Both names start as None
# and are filled in once by the startup hook. This means load_data() runs when
# the server starts rather than at import time.
documents = embeddings = None
@app.on_event("startup")
def startup_event():
    """Load the RAG corpus into module-level state when the server starts.

    Calls ``load_data()`` once and publishes its two results as the module
    globals ``documents`` and ``embeddings``, which the ``/ask`` handler reads.

    NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of a
    ``lifespan`` handler passed to ``FastAPI(...)``; consider migrating.
    """
    global documents, embeddings
    loaded_docs, loaded_embeddings = load_data()
    documents = loaded_docs
    embeddings = loaded_embeddings
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` for the root URL with an empty answer."""
    page_context = {"request": request, "answer": ""}
    return templates.TemplateResponse("index.html", page_context)
# @app.post("/ask", response_class=HTMLResponse)
# async def ask(request: Request, query: str = Form(...)):
# # Retrieve relevant chunks
# retrieved = retrieve_chunks(query, documents, embeddings)
# context = "\n\n".join(retrieved)
# # Ask the model
# answer = generate_answer(context, query)
# return templates.TemplateResponse(
# "index.html",
# {"request": request, "answer": answer, "query": query}
# )
@app.post("/ask")
def ask(query: str = Form(...)):
    """Answer a user question using retrieval-augmented generation.

    Fetches context chunks for ``query`` from the corpus loaded at startup
    (via ``retrieve_chunks``) and joins them into one context string. It then
    asks ``generate_answer`` for a response.

    This is a plain ``def`` rather than ``async def`` on purpose. Both
    ``retrieve_chunks`` and ``generate_answer`` are synchronous calls, and
    FastAPI runs ``def`` endpoints in its worker threadpool. Inside an
    ``async def`` these calls would block the event loop, and with it every
    other request, for the full retrieval and generation time.

    NOTE(review): handlers can now run concurrently in worker threads.
    Confirm that ``retrieve_chunks`` and ``generate_answer`` are safe to call
    from several threads at once.

    Args:
        query: The question, taken from the ``query`` form field.

    Returns:
        A dict that FastAPI serialises to JSON as ``{"answer": ...}``.
    """
    retrieved = retrieve_chunks(query, documents, embeddings)
    # Blank-line separators keep the retrieved chunks visually distinct in the
    # context handed to the model.
    context = "\n\n".join(retrieved)
    answer = generate_answer(context, query)
    return {"answer": answer}