subbunanepalli commited on
Commit
9175e50
·
verified ·
1 Parent(s): 98fa3c3

Create app/model.py

Browse files
Files changed (1) hide show
  1. app/model.py +37 -0
app/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer, BertModel
4
+ import pickle
5
+ from app.utils import preprocess
6
+
7
+ class BertForMultiLabel(nn.Module):
8
+ def __init__(self, num_labels):
9
+ super().__init__()
10
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
11
+ self.dropout = nn.Dropout(0.3)
12
+ self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
13
+
14
+ def forward(self, input_ids, attention_mask):
15
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
16
+ pooled_output = self.dropout(outputs.pooler_output)
17
+ logits = self.classifier(pooled_output)
18
+ return logits
19
+
20
+ def load_model():
21
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
22
+ with open("app/mlb_classes.pkl", "rb") as f:
23
+ classes = pickle.load(f)
24
+
25
+ model = BertForMultiLabel(num_labels=len(classes))
26
+ model.load_state_dict(torch.load("app/bert_multilabel_model.pth", map_location="cpu"))
27
+ model.eval()
28
+ return model, tokenizer, classes
29
+
30
+ def predict(text, model, tokenizer, mlb_classes, threshold=0.5):
31
+ model.eval()
32
+ inputs = preprocess(text, tokenizer)
33
+ with torch.no_grad():
34
+ logits = model(**inputs)
35
+ probs = torch.sigmoid(logits).squeeze()
36
+ pred_labels = [mlb_classes[i] for i, prob in enumerate(probs) if prob >= threshold]
37
+ return pred_labels