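# src/api.py — FastAPI service exposing a BERT-based disease prediction model.
# Provenance (from the hosting page): repo "vvvvvv", file src/api.py,
# uploaded by Reyall in commit 84187cf ("Upload 9 files").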
import os
import pickle

import torch
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification
app = FastAPI()

# Artifact locations can be overridden via environment variables so a
# fine-tuned checkpoint can be served without editing this file. The defaults
# are exactly the values this module has always used.
LABEL_ENCODER_PATH = os.environ.get("LABEL_ENCODER_PATH", "label_encoder.pkl")
TOKENIZER_NAME_OR_PATH = os.environ.get("TOKENIZER_NAME_OR_PATH", "bert-base-uncased")
MODEL_NAME_OR_PATH = os.environ.get("MODEL_NAME_OR_PATH", "bert-base-uncased")

# Load the fitted label encoder; its `classes_` gives the label names in
# class-index order and fixes the number of output labels below.
# SECURITY: pickle.load can execute arbitrary code stored in the file, so
# LABEL_ENCODER_PATH must only ever point at a trusted artifact.
with open(LABEL_ENCODER_PATH, "rb") as f:
    label_encoder = pickle.load(f)

# Load the tokenizer and the sequence-classification model.
# NOTE(review): with the default 'bert-base-uncased' checkpoint the
# classification head is newly (randomly) initialized — transformers warns
# about this at load time — so predictions carry no meaning until
# MODEL_NAME_OR_PATH points at a fine-tuned checkpoint.
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME_OR_PATH)
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME_OR_PATH,
    num_labels=len(label_encoder.classes_),
)
model.eval()  # inference mode: disables dropout
# Request body schema for POST /predict, i.e. {"text": "..."}.
class TextRequest(BaseModel):
    # Free-form input text to be classified.
    text: str
@app.get("/")
def home():
    # Root endpoint: always answers with the same fixed status message, so a
    # successful response only indicates that the app is up.
    status_payload = {"message": "Disease prediction API is running!"}
    return status_payload
@app.post("/predict")
async def predict_endpoint(request: TextRequest):
    """Classify the submitted text and return a probability for every label.

    The response is ``{"predictions": {label: probability, ...}}`` with one
    entry per class in the label encoder; the probabilities are a softmax over
    the model's logits.
    """
    # Tokenization and the BERT forward pass are CPU-bound and synchronous;
    # running them in the thread pool keeps the event loop (and every other
    # request) responsive while inference is in progress.
    probs = await run_in_threadpool(_class_probabilities, request.text)

    # Map probabilities back to label names. Keys are stringified because JSON
    # object keys must be strings and numpy integer labels are not
    # JSON-serializable; string labels come out unchanged.
    labels = label_encoder.classes_
    return {"predictions": {str(label): prob for label, prob in zip(labels, probs)}}


def _class_probabilities(text: str):
    """Run the model on a single text and return per-class softmax probabilities as a list of floats."""
    # Tokenize the input text, truncating to at most 128 tokens.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    # no_grad is thread-local, so it must be entered here, in the worker thread.
    with torch.no_grad():
        logits = model(**inputs).logits
    # logits is (batch_size=1, num_labels). Index the single batch row rather
    # than squeeze(): squeeze() would also drop the label axis when there is
    # only one label, turning the result into a bare float that zip() rejects.
    return torch.nn.functional.softmax(logits, dim=-1)[0].tolist()