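# Captured from what appears to be the Hugging Face file viewer (page chrome, not code):
# author: Reyall · commit a99556d "Upload 2 files" (verified) · file size 857 bytes
# (viewer actions: raw | history | blame | contribute | delete)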
from pathlib import Path

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from safetensors.torch import load_file
from transformers import AutoTokenizer, BertForSequenceClassification
app = FastAPI()

# Tokenizer and model architecture/weights come from the Hub repository.
tokenizer = AutoTokenizer.from_pretrained("Reyall/nlp-disease-model")
model = BertForSequenceClassification.from_pretrained("Reyall/nlp-disease-model")

# The local checkpoint is loaded on top of the Hub weights. Resolve it next to
# this source file instead of relying on the process's current working
# directory, which depends on how the server is launched (e.g. `uvicorn`
# started from another directory). Fall back to the original CWD-relative
# path so deployments that keep the file in the working directory still work.
_weights_path = Path(__file__).resolve().with_name("model.safetensors")
if not _weights_path.is_file():
    _weights_path = Path("model.safetensors")
state_dict = load_file(str(_weights_path))
model.load_state_dict(state_dict)

# Inference-only service: switch off training-time behaviour such as dropout.
model.eval()
class TextRequest(BaseModel):
    """JSON request body accepted by the ``/predict`` endpoint.

    Attributes:
        text: The raw input string to be tokenized and classified.
    """

    text: str
@app.post("/predict")
async def predict_endpoint(request: TextRequest):
    """Classify ``request.text`` and return per-label probabilities.

    The text is tokenized with truncation enabled and run through the
    module-level ``model`` with gradient tracking disabled.

    Returns:
        ``{"probs": [...]}``: a list holding one softmax probability per
        output label of the model, in label-index order.
    """
    # NOTE(review): tokenization and the forward pass are synchronous and
    # CPU/GPU-bound, so inside an ``async def`` they block the event loop for
    # the duration of each request. Moving them to a worker thread would need
    # a check that the tokenizer/model are safe to call concurrently.
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # A single string was tokenized, so logits carry a batch axis of size 1.
    # Select that row explicitly: a bare .squeeze() would also collapse the
    # label axis for a single-label head, making .tolist() return a float
    # rather than a list and changing the response shape.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()
    return {"probs": probs}