bert-fastapi / app /main.py
JaySenpai's picture
fixed issues
1baa976 verified
import asyncio
import os

import numpy as np
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.preprocessing import LabelEncoder
from transformers import BertForSequenceClassification, BertTokenizer
# Hugging Face cache directory for any files from_pretrained has to download.
# NOTE: transformers (imported at the top of this module) resolves its default
# cache location from the environment when it is imported, so assigning the
# variable here is too late to change that default. The directory is therefore
# also passed explicitly as `cache_dir` below, which from_pretrained uses for
# downloads regardless of import order. The env var is still set for any later
# code or subprocess that reads it.
CACHE_DIR = "/code/cache"
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Directory holding the fine-tuned model and tokenizer. Defaults to
# "./bert-model", which is resolved against the process's current working
# directory (not this file's location); set BERT_MODEL_DIR to override it when
# the server is launched from somewhere else.
MODEL_DIR = os.environ.get("BERT_MODEL_DIR", "./bert-model")

model = BertForSequenceClassification.from_pretrained(
    MODEL_DIR,
    cache_dir=CACHE_DIR,
)
tokenizer = BertTokenizer.from_pretrained(
    MODEL_DIR,
    cache_dir=CACHE_DIR,
)

# Inference only: switch layers such as dropout to evaluation behaviour.
model.eval()
# Label decoder: only `classes_` is needed to map a predicted index back to
# its category, so it is populated directly instead of calling fit().
le = LabelEncoder()
# The class list is stored next to this module, so build the path from
# __file__ rather than relying on the current working directory.
label_path = os.path.join(os.path.dirname(__file__), "label_classes.npy")
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code - ensure this file is always trusted.
le.classes_ = np.load(label_path, allow_pickle=True)
class TextInput(BaseModel):
    """Request body schema for ``POST /predict``."""

    # The raw text to classify.
    text: str
@app.get("/")
def read_root():
    """Return a fixed status message pointing users at the interactive docs."""
    status_message = "FastAPI backend is live. Go to /docs to test."
    return {"message": status_message}
def _predict_sync(text):
    """Tokenize *text*, run the classifier, and return the top class.

    Blocking and compute-bound; ``predict`` runs it in a worker thread.

    Returns:
        ``{"predicted_category": label, "confidence": float}`` where ``label``
        is the matching entry of ``le.classes_`` (as a plain Python value) and
        ``confidence`` is its softmax probability.
    """
    # Cap the sequence at what the model can embed. If the saved tokenizer
    # config lacks ``model_max_length``, transformers falls back to a huge
    # sentinel and ``truncation=True`` alone would not shorten long inputs,
    # so BERT would fail on texts longer than its position-embedding table.
    max_length = min(
        tokenizer.model_max_length,
        model.config.max_position_embeddings,
    )
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # Grad mode is thread-local, so no_grad must be entered in this thread.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    pred_class = torch.argmax(probs, dim=1).item()
    pred_label = le.classes_[pred_class]
    # Entries of an array loaded with np.load may be NumPy scalars
    # (np.str_, np.int64, ...); numeric ones are not JSON-encodable by
    # FastAPI, so unwrap them to the equivalent built-in Python value.
    if isinstance(pred_label, np.generic):
        pred_label = pred_label.item()
    confidence = probs[0][pred_class].item()
    return {"predicted_category": pred_label, "confidence": confidence}


@app.post("/predict")
async def predict(data: TextInput):
    """Classify ``data.text`` and return the most likely category.

    Returns:
        dict with ``"predicted_category"`` (an entry of ``le.classes_``) and
        ``"confidence"`` (that class's softmax probability, in [0, 1]).
    """
    # Model inference blocks; running it directly in this coroutine would
    # stall the event loop and queue every other request behind it. Run it in
    # the default thread pool and await the result instead.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _predict_sync, data.text)