# app/main.py (482 bytes) - commit 98fa3c3 "Create app/main.py" by subbunanepalli.
import threading

from fastapi import FastAPI, Request
from pydantic import BaseModel

from app.model import load_model, predict
app = FastAPI()

# Load the model artifacts once, when this module is imported, so every
# request handled below reuses the same objects instead of reloading them.
# NOTE(review): `mlb` is presumably a label binarizer used to map model
# outputs back to label names - confirm against app.model.load_model.
(model, tokenizer, mlb) = load_model()
class PredictRequest(BaseModel):
    # JSON request body accepted by POST /predict, e.g. {"text": "..."}.
    # (Comments rather than a docstring: pydantic would copy a class
    # docstring into the generated OpenAPI schema.)
    text: str  # raw input text handed to the multi-label classifier
# Serializes calls into `predict`. The endpoint below is a plain `def`, so
# FastAPI runs it in a worker thread and several requests could otherwise run
# inference at once. The shared model/tokenizer objects are not shown to be
# thread-safe, so inference is kept one-at-a-time, which is how it effectively
# behaved when it ran on the event loop.
# NOTE(review): drop the lock if predict() is confirmed safe for concurrent use.
_predict_lock = threading.Lock()


@app.post("/predict")
def predict_labels(request: PredictRequest):
    """Predict the labels for the submitted text.

    Returns a JSON object of the form ``{"labels": ...}``.
    """
    # `predict` is synchronous (it is called without `await`). Declaring this
    # endpoint with `def` instead of `async def` lets FastAPI run it in its
    # threadpool, so the event loop stays free to serve other requests while
    # the model runs.
    with _predict_lock:
        pred_labels = predict(request.text, model, tokenizer, mlb)
    return {"labels": pred_labels}
@app.get("/")
def root():
    # Static status endpoint: always answers with the same fixed message, so a
    # caller can confirm the service is up.
    status = {"message": "BERT Multi-label API is live"}
    return status