space / app.py
darshan-20072005's picture
Create app.py
3b2ce9a verified
raw
history blame contribute delete
540 Bytes
import os

import torch
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Directory holding the fine-tuned tokenizer and classifier weights. Defaults
# to "./model", which is resolved against the process's current working
# directory (not this file's location). Set the MODEL_DIR environment variable
# to load from another directory without editing the code.
MODEL_DIR = os.environ.get("MODEL_DIR", "./model")

app = FastAPI()
# Loaded once at import time so every request reuses the same tokenizer/model
# instead of reloading them per call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)


@app.post("/predict")
def predict(text: str):
    """Classify ``text`` as productive or unproductive.

    ``text`` is declared as a plain ``str`` parameter, so FastAPI reads it from
    the query string (e.g. ``POST /predict?text=...``), not from a JSON body.

    Returns:
        ``{"label": "productive"}`` when the model's highest-scoring class
        index is 1, otherwise ``{"label": "unproductive"}``.
    """
    # truncation=True clips the token sequence to the tokenizer's configured
    # maximum length. Without it, long inputs can exceed what the model
    # accepts, and the forward pass raises instead of returning a label.
    enc = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**enc).logits
    # A single string is tokenized as a batch of one, so argmax over the label
    # axis yields one element and .item() turns it into a Python int.
    pred = torch.argmax(logits, dim=1).item()
    # NOTE(review): the index-1 -> "productive" mapping is hard-coded; confirm
    # it matches the label order the loaded model was trained with.
    label = "productive" if pred == 1 else "unproductive"
    return {"label": label}