import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pickle

from app.utils import preprocess


class BertForMultiLabel(nn.Module):
    """BERT encoder with a linear head for multi-label classification.

    Emits raw logits (one per label); callers apply sigmoid per label
    rather than softmax across labels.
    """

    def __init__(self, num_labels: int):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, num_labels).

        Uses the [CLS] pooled output; dropout is active only in train mode.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)
        return logits


def load_model():
    """Load tokenizer, label classes, and trained weights from disk.

    Returns:
        (model, tokenizer, classes) with the model in eval mode on CPU.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # NOTE(review): pickle.load executes arbitrary code if the file is
    # tampered with — only load artifacts you produced yourself.
    with open("app/mlb_classes.pkl", "rb") as f:
        classes = pickle.load(f)
    model = BertForMultiLabel(num_labels=len(classes))
    # weights_only=True restricts unpickling to tensors/primitives, which is
    # all a state_dict contains — avoids torch.load's code-execution risk.
    model.load_state_dict(
        torch.load("app/bert_multilabel_model.pth",
                   map_location="cpu", weights_only=True)
    )
    model.eval()
    return model, tokenizer, classes


def predict(text, model, tokenizer, mlb_classes, threshold: float = 0.5):
    """Return the list of label names whose sigmoid probability >= threshold.

    Args:
        text: raw input string (tokenized via app.utils.preprocess).
        model: a BertForMultiLabel instance.
        tokenizer: matching BertTokenizer (consumed by preprocess).
        mlb_classes: sequence of label names, index-aligned with the head.
        threshold: per-label decision cutoff on the sigmoid probability.
    """
    model.eval()  # defensive: guarantee dropout is disabled
    inputs = preprocess(text, tokenizer)
    with torch.no_grad():
        logits = model(**inputs)
        # squeeze(0) removes only the batch dim. A bare .squeeze() would also
        # collapse the label dim when num_labels == 1, yielding a 0-dim tensor
        # that breaks enumerate().
        probs = torch.sigmoid(logits).squeeze(0)
    pred_labels = [mlb_classes[i] for i, prob in enumerate(probs)
                   if prob >= threshold]
    return pred_labels