deberta.space / main.py
ganeshkonapalli's picture
Upload 6 files
0e73d34 verified
raw
history blame contribute delete
842 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
class InputText(BaseModel):
    """Request body for the ``/predict`` endpoint."""

    text: str  # raw string that is tokenized and fed to the model
app = FastAPI()


def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code during
    deserialization; only point this at artifacts you produced yourself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Deserialize the inference artifacts once, at import time, so every request
# reuses the same objects. Paths are resolved against the process's current
# working directory, not this file's location.
model = _load_pickle("app/model.pkl")
tokenizer = _load_pickle("app/tokenizer.pkl")
label_encoder = _load_pickle("app/label_encoder.pkl")

# Put the model in evaluation mode (for a torch module this disables
# training-only behavior such as dropout).
model.eval()
@app.get("/")
def read_root():
    """Return a fixed message confirming the API process is up."""
    status = {"message": "DeBERTa Model is live!"}
    return status
@app.post("/predict")
def predict(input: InputText):  # noqa: A002 - parameter name kept for interface stability
    """Classify ``input.text`` and return the decoded label.

    The text is tokenized (truncation enabled), run through the model with
    gradient tracking disabled, and the index of the highest logit is mapped
    back to its original label with ``label_encoder.inverse_transform``.

    Returns:
        ``{"prediction": label}``, where ``label`` is unwrapped from a NumPy
        scalar to a native Python value so FastAPI can JSON-encode it.
    """
    inputs = tokenizer(input.text, return_tensors="pt", truncation=True, padding=True)

    # Align the encoded tensors with the model's device. A model unpickled onto
    # a GPU would otherwise reject CPU inputs. Models without a ``device``
    # attribute are left as-is (the original behavior).
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Assumes logits are (batch=1, num_labels) for a single input string;
    # TODO confirm for this model.
    pred = torch.argmax(outputs.logits, dim=1).item()
    label = label_encoder.inverse_transform([pred])[0]

    # inverse_transform returns NumPy scalars (e.g. numpy.int64 for integer
    # labels), which FastAPI's JSON encoder cannot serialize. Unwrap them to
    # the equivalent built-in Python value.
    if hasattr(label, "item"):
        label = label.item()
    return {"prediction": label}