|
|
from fastapi import FastAPI
|
|
|
from pydantic import BaseModel
|
|
|
from transformers import BertTokenizer, BertForSequenceClassification
|
|
|
import torch
|
|
|
import pickle
|
|
|
|
|
|
# --- Application setup: runs once at import time, shared by all requests ---

# FastAPI application instance; the routes below register against it.
app = FastAPI()

# Restore the fitted label encoder mapping class indices -> disease names.
# NOTE(review): pickle.load executes arbitrary code from the file — only
# ship label_encoder.pkl from a trusted source.
with open("label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)

# Tokenizer/model are loaded once at startup (may download on first run).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# num_labels is derived from the encoder so the classification head size
# matches the label set.
# NOTE(review): this loads the *base* pretrained weights, so the
# classification head is randomly initialized — if a fine-tuned checkpoint
# exists, its path should be used here instead. Confirm against training code.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_encoder.classes_))

# Inference mode: disables dropout so outputs are deterministic.
model.eval()
|
|
|
|
|
|
|
|
|
class TextRequest(BaseModel):
    """Request body for ``POST /predict``: a free-text description to classify."""

    # Raw input text; tokenized and fed to the model by the endpoint.
    text: str
|
|
|
|
|
|
@app.get("/")
def home():
    """Health-check endpoint confirming the service is up."""
    payload = {"message": "Disease prediction API is running!"}
    return payload
|
|
|
|
|
|
@app.post("/predict")
def predict_endpoint(request: TextRequest):
    """Classify ``request.text`` and return per-class softmax probabilities.

    Returns a JSON object ``{"predictions": {label: probability, ...}}``
    covering every class known to the label encoder.

    Note: declared as a plain ``def`` (not ``async def``) so FastAPI runs
    this blocking model inference in its threadpool instead of stalling
    the event loop.
    """
    # Tokenize to model tensors; truncate long inputs to the 128-token cap.
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=128)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # squeeze() drops the batch dim of the single input -> 1-D probability list.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()

    # label_encoder.classes_ is a numpy array; coerce keys/values to plain
    # Python str/float so the response is always JSON-serializable regardless
    # of the encoder's dtype.
    labels = label_encoder.classes_
    return {"predictions": {str(label): float(p) for label, p in zip(labels, probs)}}
|
|
|
|