# Commit fb8d240 by MedhaCodes (verified): "Update api_app.py"
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
import os
app = FastAPI(title="QA Dashboard Pro")

# Identifier handed to `from_pretrained` for both the model and its tokenizer
# (an "owner/name" repo id).
MODEL_PATH = "MedhaCodes/qna_finetuned_model"

# Build the extractive question-answering pipeline once, at import time, so
# every request reuses the same loaded weights. The model is loaded before the
# tokenizer, matching the original keyword-argument evaluation order.
_qa_model = AutoModelForQuestionAnswering.from_pretrained(MODEL_PATH)
_qa_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
qa_pipeline = pipeline(
    "question-answering",
    model=_qa_model,
    tokenizer=_qa_tokenizer,
)
# Resolve asset directories relative to this file rather than the process's
# current working directory. Previously only "static" was file-relative while
# "templates" depended on the CWD, so launching the server from another
# directory broke page rendering.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Serve static assets (CSS, JS) under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(_BASE_DIR, "static")),
    name="static",
)

# Jinja2 templates used to render the HTML pages.
templates = Jinja2Templates(directory=os.path.join(_BASE_DIR, "templates"))
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the ``index.html`` template as the dashboard landing page."""
    # The template context must carry the incoming request object.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/predict")
async def predict(request: Request):
    """Answer one or more questions against a single context passage.

    Expects a JSON object body with:
      * ``context``  -- the passage to extract answers from (non-blank string).
      * ``question`` -- one or more questions, one per line (string).

    Returns ``{"results": [...]}`` with one ``{"question", "answer", "score"}``
    entry per non-blank question line, in input order. If answering a single
    question fails, that entry gets ``answer="Error: ..."`` and ``score=0``
    (and the failure is logged) instead of failing the whole request.

    Malformed input (invalid JSON, a non-object body, missing/blank or
    non-string fields) yields HTTP 400 with an ``{"error": ...}`` body.
    """
    import logging

    logger = logging.getLogger(__name__)

    # An empty or invalid body raises json.JSONDecodeError, and undecodable
    # bytes raise UnicodeDecodeError. Both are ValueError subclasses; report
    # them as a client error rather than a 500.
    try:
        data = await request.json()
    except ValueError:
        return JSONResponse({"error": "Request body must be valid JSON"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    context = data.get("context")
    questions_text = data.get("question")
    if not context or not questions_text:
        return JSONResponse({"error": "Please provide both context and question"}, status_code=400)
    if not isinstance(context, str) or not isinstance(questions_text, str):
        return JSONResponse({"error": "context and question must be strings"}, status_code=400)

    # One question per line. splitlines() also handles "\r\n" and lone "\r";
    # blank lines are dropped.
    questions = [line.strip() for line in questions_text.splitlines() if line.strip()]
    if not context.strip() or not questions:
        # Whitespace-only input would otherwise slip past the check above.
        return JSONResponse({"error": "Please provide both context and question"}, status_code=400)

    # NOTE(review): qa_pipeline is called synchronously inside this async
    # handler, so it blocks the event loop while it runs. Consider offloading
    # it if concurrent request handling matters.
    answers = []
    for question in questions:
        try:
            result = qa_pipeline(question=question, context=context)
            entry = {
                "question": question,
                "answer": result["answer"],
                # float() guards against a NumPy scalar score, which the JSON
                # encoder may not serialize.
                "score": round(float(result["score"]), 4),
            }
        except Exception as exc:  # best-effort: keep answering the remaining questions
            logger.exception("QA pipeline failed for question %r", question)
            entry = {"question": question, "answer": f"Error: {exc}", "score": 0}
        answers.append(entry)

    return {"results": answers}