roberta / app.py
subbu123456's picture
Upload 4 files
0a5bbcc verified
raw
history blame contribute delete
599 Bytes
from pathlib import Path

import joblib
from fastapi import FastAPI, Request
from pydantic import BaseModel
# ASGI application object; endpoints below are registered on it.
app = FastAPI()

# Resolve the model artifact next to this module rather than relative to the
# process's current working directory, so the app still finds it when the
# server is started from a different directory.
MODEL_PATH = Path(__file__).resolve().parent / "roberta_model.pkl"

# Load the serialized model bundle once, at import time. predict() reads the
# "tokenizer", "model" and "torch" entries from this object.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only ever point this at a trusted artifact shipped with the app, never at
# user-supplied data.
with open(MODEL_PATH, "rb") as f:
    model = joblib.load(f)

# Switch to inference mode so dropout (and similar train-only layers) are
# disabled; a model pickled straight out of a training loop may still be in
# training mode, which would make predictions nondeterministic.
# NOTE(review): assumes model["model"] is a torch nn.Module (it is called on
# return_tensors="pt" inputs under no_grad and exposes `.logits`) — confirm.
model["model"].eval()
# Request body schema for POST /predict.
# (Documented with comments rather than a docstring: Pydantic would publish a
# class docstring as the schema description in the OpenAPI output.)
class InputText(BaseModel):
    # Raw input string; required (no default). predict() passes it to the
    # tokenizer and then to the model.
    text: str
# POST /predict: run the loaded model on the request text and return the
# argmax over the logits' dim 1.
# (Comments instead of a docstring: FastAPI would expose a docstring as the
# operation description in the OpenAPI schema.)
@app.post("/predict")
def predict(data: InputText):
    tokenizer = model["tokenizer"]
    classifier = model["model"]

    # Encode as PyTorch tensors; padding/truncation flags are passed straight
    # through to the tokenizer.
    encoded = tokenizer(
        data.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Forward pass without gradient tracking.
    with model["torch"].no_grad():
        logits = classifier(**encoded).logits

    # Presumably logits is (batch, num_labels), so this yields one class index
    # per encoded sequence — confirm against the model's output head.
    return {"predictions": logits.argmax(dim=1).tolist()}