"""FastAPI service exposing a BERT sequence classifier at ``POST /predict``.

The tokenizer and model architecture are loaded from the Hugging Face repo
``Reyall/nlp-disease-model``; the model's weights are then replaced by the
tensors stored in a local ``model.safetensors`` file.
"""

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from safetensors.torch import load_file
from transformers import AutoTokenizer, BertForSequenceClassification

MODEL_ID = "Reyall/nlp-disease-model"
# Relative path: resolved against the process's current working directory,
# not this module's location.
# NOTE(review): confirm the server is always launched from the directory that
# holds the checkpoint, or resolve this relative to __file__ instead.
WEIGHTS_PATH = "model.safetensors"

app = FastAPI()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = BertForSequenceClassification.from_pretrained(MODEL_ID)
# Overwrite the hub weights with the local checkpoint. The loaded dict is not
# kept in a module global, so its tensors can be freed once they have been
# copied into the model's parameters. load_state_dict is strict by default,
# so a key mismatch fails at startup instead of silently mixing weights.
model.load_state_dict(load_file(WEIGHTS_PATH))
model.eval()  # inference mode for layers such as dropout


class TextRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # raw input text to classify


@app.post("/predict")
def predict_endpoint(request: TextRequest):
    """Classify ``request.text`` and return per-label probabilities.

    Declared as a plain ``def`` rather than ``async def``: tokenization and
    the forward pass are blocking work, and FastAPI runs sync path functions
    in its threadpool, so a slow prediction no longer stalls the event loop
    (and every other in-flight request).

    Returns:
        ``{"probs": [...]}`` -- softmax probabilities, one float per output
        label, in the model's label-index order. Always a list, even when the
        model has a single label.
    """
    # A single string is encoded as a batch of one sequence.
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch=1, num_labels)
    # Take the single batch row instead of calling squeeze(): squeeze would also
    # drop the label axis when num_labels == 1, returning a bare float.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    return {"probs": probs}