"""FastAPI service that serves a BERT-based multi-label text classifier.

At import time the module loads a tokenizer and a BERT encoder from the local
``model`` directory next to this file, restores the fine-tuned weights from
``bert-multilabel-model.pth``, and exposes two HTTP endpoints:

* ``GET /``         -- liveness message.
* ``POST /predict`` -- per-label sigmoid probabilities for one text.
"""

import os

import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertModel, BertTokenizerFast

# Define constants
MODEL_PATH = os.path.join(os.path.dirname(__file__), "model")
WEIGHTS_PATH = os.path.join(MODEL_PATH, "bert-multilabel-model.pth")
NUM_LABELS = 6  # Adjust based on your dataset

# Initialize FastAPI app
app = FastAPI()

# Load tokenizer from local directory
tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)


# Define the BERT-based multi-label classifier
class BertMultiLabelClassifier(nn.Module):
    """BERT encoder followed by a single linear layer with NUM_LABELS outputs."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_PATH)
        self.classifier = nn.Linear(self.bert.config.hidden_size, NUM_LABELS)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-sigmoid) logits, one row per input sequence.

        Args:
            input_ids: Token ids produced by the tokenizer.
            attention_mask: Mask produced by the tokenizer alongside the ids.

        Returns:
            The linear layer's output for each sequence (NUM_LABELS values).
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the summary vector
        # fed to the classification head.
        cls_output = output.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)


# Load the model weights
model = BertMultiLabelClassifier()
# NOTE(review): torch.load unpickles the checkpoint, so only load files you
# trust. If the installed torch supports it, consider weights_only=True.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
model.eval()  # inference mode: disables dropout


# Input schema for prediction
class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Raw text to classify.
    text: str


@app.get("/")
def read_root():
    """Return a static message indicating the service is up."""
    return {"message": "Multi-label BERT model is running!"}


@app.post("/predict")
def predict(request: PredictRequest):
    """Classify ``request.text`` and return one probability per label.

    The text is tokenized (truncated to at most 512 tokens), run through the
    model without gradient tracking, and the logits are passed through a
    sigmoid so each label gets an independent probability.

    Returns:
        ``{"probabilities": [...]}`` with NUM_LABELS floats.
    """
    inputs = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        # Pass only the tensors forward() accepts. The tokenizer output may
        # also carry token_type_ids, which forward() has no parameter for, so
        # unpacking the whole encoding with ** would raise a TypeError.
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    # Drop only the batch dimension; a bare squeeze() would also collapse the
    # label dimension when NUM_LABELS == 1 and yield a float instead of a list.
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    return {"probabilities": probs}