# Bert-model-test / predict.py
# Provenance: uploaded to Hugging Face by ganeshkonapalli
# (commit 8505a58, "Upload 8 files", 333 bytes).
from functools import lru_cache

import torch

from app.model_utils import load_model
@lru_cache(maxsize=1)
def _get_model():
    """Load the model and tokenizer once and reuse them on later calls.

    The original code called ``load_model()`` inside every prediction,
    which re-loads the (potentially large) model per request.  Caching
    here keeps ``predict_text``'s interface unchanged while paying the
    load cost only once per process.
    """
    return load_model()


def predict_text(text: str) -> int:
    """Classify *text* and return the predicted class index.

    Parameters
    ----------
    text:
        The input string to classify.  It is tokenized with truncation
        and padding enabled, so over-long inputs are clipped rather than
        raising.

    Returns
    -------
    int
        The argmax over the model's output logits along dim 1
        (the class dimension), as a plain Python int.

    Notes
    -----
    Runs under ``torch.no_grad()`` — inference only, no autograd graph.
    NOTE(review): the model's train/eval mode is whatever ``load_model``
    returns it in; if it contains dropout/batch-norm, confirm the loader
    calls ``model.eval()``.
    """
    model, tokenizer = _get_model()
    # "pt" => PyTorch tensors; padding/truncation normalize input length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.argmax(outputs.logits, dim=1).item()