Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| import pickle | |
| from app.utils import preprocess | |
class BertForMultiLabel(nn.Module):
    """BERT encoder followed by dropout and a linear multi-label head.

    The encoder is loaded from the ``bert-base-uncased`` checkpoint and the
    head maps the encoder's pooled output to ``num_labels`` scores.
    """

    def __init__(self, num_labels):
        """Create the encoder, the dropout layer and the linear head.

        Args:
            num_labels: Number of scores the head produces per input.
        """
        super().__init__()
        # The attribute names below are also the state-dict key prefixes,
        # so they must stay stable for saved weights to load.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for the given encoder inputs.

        Args:
            input_ids: Passed to the encoder as ``input_ids``.
            attention_mask: Passed to the encoder as ``attention_mask``.

        Returns:
            The output of the linear head applied to the dropout-regularised
            pooled encoder output; no activation is applied here.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(self.dropout(encoded.pooler_output))
def load_model(model_path="app/bert_multilabel_model.pth",
               classes_path="app/mlb_classes.pkl"):
    """Load the tokenizer, the label classes and the trained classifier.

    Args:
        model_path: Path to the saved state dict for ``BertForMultiLabel``.
            Defaults to the location used by the app; the default is
            relative, so it resolves against the current working directory.
        classes_path: Path to the pickled sequence of label classes. Its
            length determines the size of the classification head.

    Returns:
        A ``(model, tokenizer, classes)`` tuple, with the model on CPU and
        switched to evaluation mode.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point classes_path at a file you produced or otherwise trust.
    with open(classes_path, "rb") as f:
        classes = pickle.load(f)

    model = BertForMultiLabel(num_labels=len(classes))

    # map_location="cpu" lets weights saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load also unpickles; consider weights_only=True if
    # the installed torch version supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer, classes
def predict(text, model, tokenizer, mlb_classes, threshold=0.5):
    """Return the labels whose sigmoid score reaches ``threshold``.

    Args:
        text: Input passed to ``preprocess`` together with ``tokenizer``.
        model: Callable model; invoked as ``model(**inputs)`` where
            ``inputs`` is whatever ``preprocess`` returns.
        tokenizer: Tokenizer forwarded to ``preprocess``.
        mlb_classes: Sequence of labels, one per model output score.
        threshold: Minimum sigmoid score (inclusive) for a label to be kept.

    Returns:
        A list of the entries of ``mlb_classes`` whose score is at least
        ``threshold``, in class order.

    Raises:
        ValueError: If the model does not produce exactly one score per
            entry in ``mlb_classes``.
    """
    model.eval()
    inputs = preprocess(text, tokenizer)
    with torch.no_grad():
        logits = model(**inputs)

    # Flatten explicitly instead of a bare squeeze(): with a single label,
    # squeeze() collapses the (1, 1) output to a 0-d tensor, which cannot
    # be iterated.
    probs = torch.sigmoid(logits).reshape(-1)
    if probs.numel() != len(mlb_classes):
        raise ValueError(
            f"model produced {probs.numel()} scores for "
            f"{len(mlb_classes)} classes"
        )

    # One vectorised comparison, then plain Python booleans for selection.
    selected = (probs >= threshold).tolist()
    return [label for label, keep in zip(mlb_classes, selected) if keep]