"""FastAPI inference service for a fine-tuned BERT sequence classifier.

Loads a local BERT checkpoint plus the label-encoder classes saved at
training time, and exposes a single /predict endpoint that returns the
predicted category and its softmax confidence for an input text.
"""

import os

# BUGFIX: TRANSFORMERS_CACHE must be set BEFORE `transformers` is imported —
# the library resolves its cache directory at import time, so setting the
# variable after the import (as the original code did) had no effect.
os.environ["TRANSFORMERS_CACHE"] = "/code/cache"

import numpy as np
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.preprocessing import LabelEncoder
from transformers import BertForSequenceClassification, BertTokenizer

app = FastAPI()

# Local checkpoint directory containing both model weights and tokenizer files.
MODEL_DIR = "./bert-model"

model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model.eval()  # inference mode: disables dropout for deterministic outputs

# Restore the label mapping produced during training. The path is resolved
# relative to this file so the app works regardless of the current directory.
# NOTE(review): allow_pickle=True is only safe because this file ships with
# the app; never load untrusted .npy files with pickle enabled.
label_path = os.path.join(os.path.dirname(__file__), "label_classes.npy")
le = LabelEncoder()
le.classes_ = np.load(label_path, allow_pickle=True)


class TextInput(BaseModel):
    """Request body for /predict."""

    # Raw text to classify.
    text: str


@app.get("/")
def read_root():
    """Health-check / landing endpoint."""
    return {"message": "FastAPI backend is live. Go to /docs to test."}


@app.post("/predict")
async def predict(data: TextInput):
    """Classify ``data.text``.

    Returns:
        dict with ``predicted_category`` (label string from the trained
        encoder) and ``confidence`` (softmax probability of that class).
    """
    inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    pred_class = torch.argmax(probs, dim=-1).item()
    pred_label = le.classes_[pred_class]
    confidence = probs[0][pred_class].item()
    return {"predicted_category": pred_label, "confidence": confidence}