compliance-data / app /model.py
subbunanepalli's picture
Create app/model.py
9175e50 verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pickle
from app.utils import preprocess
class BertForMultiLabel(nn.Module):
    """BERT encoder topped with dropout and a linear head for multi-label output.

    The head produces one raw logit per label; no sigmoid is applied here, so
    callers convert logits to independent per-label probabilities themselves.

    Args:
        num_labels: Number of output labels (width of the linear head).
        pretrained_model_name: Model id or local path handed to
            ``BertModel.from_pretrained``. Defaults to ``"bert-base-uncased"``.
        dropout: Dropout probability applied to BERT's pooled output.
            Defaults to ``0.3``.
    """

    def __init__(self, num_labels, pretrained_model_name="bert-base-uncased", dropout=0.3):
        super().__init__()
        # Attribute names (bert / dropout / classifier) define the state_dict
        # keys, so they must stay stable for saved checkpoints to load.
        # Pretrained weights fetched here are overwritten if a fine-tuned
        # state_dict is loaded afterwards.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return unnormalised label logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token id tensor for BERT.
            attention_mask: Mask tensor marking real (non-padding) tokens.
            token_type_ids: Optional segment ids. Accepted so that a tokenizer
                output dict can be splatted straight in with ``**``; ``None``
                is BERT's own default (all-zero segments).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)
def load_model(
    classes_path="app/mlb_classes.pkl",
    weights_path="app/bert_multilabel_model.pth",
):
    """Load the tokenizer, the fine-tuned classifier and its label names.

    The default paths are relative to the current working directory, so with
    no arguments this must be called from the directory that contains
    ``app/``. Pass explicit paths to load from anywhere else.

    Args:
        classes_path: Pickle file holding the label sequence; element ``i``
            names output logit ``i`` (presumably a fitted
            ``MultiLabelBinarizer.classes_`` -- confirm against training code).
        weights_path: ``state_dict`` checkpoint for ``BertForMultiLabel``,
            loaded onto the CPU.

    Returns:
        Tuple ``(model, tokenizer, classes)`` with ``model`` in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # SECURITY: pickle.load can execute arbitrary code -- only load files
    # produced by a trusted training run.
    with open(classes_path, "rb") as f:
        classes = pickle.load(f)
    model = BertForMultiLabel(num_labels=len(classes))
    # SECURITY: torch.load is also pickle-based; the checkpoint must be trusted
    # (consider weights_only=True on torch versions that support it).
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer, classes
def predict(text, model, tokenizer, mlb_classes, threshold=0.5):
    """Return the labels whose sigmoid probability is at least ``threshold``.

    Args:
        text: Input handed to ``preprocess`` together with ``tokenizer``
            (presumably a single string -- confirm against ``app.utils``).
        model: Multi-label classifier returning logits of shape
            ``(batch, num_labels)``; a single example (batch of 1) is expected.
        tokenizer: Tokenizer forwarded to ``preprocess``.
        mlb_classes: Label names; element ``i`` corresponds to logit ``i``.
        threshold: Inclusive cut-off on each label's probability.

    Returns:
        List of selected label names in label order (empty if none pass).
    """
    # Switches the model to eval mode (disables dropout) as a side effect.
    model.eval()
    inputs = preprocess(text, tokenizer)
    with torch.no_grad():
        logits = model(**inputs)
    # Drop only the batch axis. A bare .squeeze() would also collapse the
    # label axis when num_labels == 1, leaving a 0-d tensor that cannot be
    # iterated.
    probs = torch.sigmoid(logits).squeeze(0)
    # One vectorised comparison, then plain Python bools for the selection.
    selected = (probs >= threshold).tolist()
    return [mlb_classes[i] for i, keep in enumerate(selected) if keep]