# Spaces:
# Build error
# Build error
import os

import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizerFast, BertModel
# Define constants
# Model artefacts live in a "model" directory next to this file. The base
# directory is made absolute so the paths stay valid even if the process
# changes its working directory after import.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model")
WEIGHTS_PATH = os.path.join(MODEL_PATH, "bert-multilabel-model.pth")

# Output width of the classification head.
NUM_LABELS = 6  # Adjust based on your dataset
# Initialize FastAPI app
app = FastAPI()

# Load tokenizer from local directory. This runs once at import time so the
# same tokenizer instance is reused by every request.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)
# Define the BERT-based multi-label classifier
class BertMultiLabelClassifier(nn.Module):
    """BERT encoder followed by one linear layer that emits per-label logits.

    The output is raw logits; no activation is applied inside the module.
    """

    def __init__(self, model_path=None, num_labels=None):
        """Build the encoder and the classification head.

        Args:
            model_path: Location passed to ``BertModel.from_pretrained``.
                Defaults to the module-level ``MODEL_PATH``.
            num_labels: Output width of the linear head. Defaults to the
                module-level ``NUM_LABELS``.
        """
        super().__init__()
        # Defaults are resolved here rather than in the signature so the
        # module-level constants are read when the model is constructed.
        if model_path is None:
            model_path = MODEL_PATH
        if num_labels is None:
            num_labels = NUM_LABELS
        self.bert = BertModel.from_pretrained(model_path)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return classification logits for a batch of tokenized sequences.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.
            token_type_ids: Optional segment id tensor. Accepted so the full
                output of a BERT tokenizer can be unpacked into this call;
                ``None`` lets the encoder use its own default.

        Returns:
            The linear head applied to the encoder's hidden state at
            sequence position 0 of every batch item.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Keep only the first position of each sequence (the [CLS] slot for
        # BERT-style inputs) as the pooled sentence representation.
        cls_output = encoded.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)
# Load the model weights
model = BertMultiLabelClassifier()
# Only a state dict is needed, so weights_only=True limits unpickling to
# tensors and plain containers instead of arbitrary Python objects.
model.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
)
# Switch to evaluation mode for inference.
model.eval()
# Input schema for prediction
class PredictRequest(BaseModel):
    """Request payload accepted by ``predict``."""

    # Raw text to classify.
    text: str
# Registered on the application root; without a route decorator FastAPI
# never serves this function.
@app.get("/")
def read_root():
    """Return a fixed status message indicating the service is running."""
    return {"message": "Multi-label BERT model is running!"}
# NOTE(review): the "/predict" path is inferred from the function name;
# confirm it matches what clients call.
@app.post("/predict")
def predict(request: PredictRequest):
    """Return per-label sigmoid probabilities for the submitted text.

    Args:
        request: Payload whose ``text`` field is tokenized (truncated to at
            most 512 tokens) and run through the classifier.

    Returns:
        A dict with the key ``"probabilities"`` mapping to a list of floats,
        one per label.
    """
    inputs = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        # Pass the two tensors explicitly: BERT tokenizers also return
        # token_type_ids, and unpacking the whole encoding would hand the
        # model a keyword its forward() may not accept.
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    # squeeze(0) removes only the batch dimension, so the result stays a list
    # even if the head has a single label.
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    return {"probabilities": probs}