gk / model_utils.py
ganeshkonapalli's picture
Upload 5 files
10c2ac1 verified
import pandas as pd
import torch
import pickle
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertModel
from torch.optim import AdamW
from tqdm import tqdm
# CSV column holding the free-text input that is tokenized for BERT.
TEXT_COLUMN = 'Sanction_Context'

# CSV columns used as classification targets. Each one gets its own
# LabelEncoder and its own classifier head, in this order.
LABEL_COLUMNS = [
    'Red_Flag_Reason',
    'Maker_Action',
    'Escalation_Level',
    'Risk_Category',
    'Risk_Drivers',
    'Investigation_Outcome',
]

# Checkpoint name passed to both the tokenizer and the encoder.
PRETRAINED_MODEL_NAME = 'bert-base-uncased'

# Tokenizer truncation limit (passed as ``max_length``).
MAX_LEN = 128

# Number of training examples sliced off per optimizer step.
BATCH_SIZE = 16

# Number of full passes over the training split.
EPOCHS = 1

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
class BertMultiOutput(nn.Module):
    """BERT encoder with one independent linear classification head per output.

    The pooled encoder output is passed through dropout and then fed to every
    head, so all heads share the same encoder representation.

    Args:
        num_labels_per_output: Iterable of ints; entry ``i`` is the number of
            classes predicted by head ``i``. The order fixes the order of the
            logits returned by :meth:`forward`.
        pretrained_model_name: Checkpoint name or path handed to
            ``BertModel.from_pretrained``. ``None`` (the default) selects the
            module-level ``PRETRAINED_MODEL_NAME``.
        dropout: Dropout probability applied to the pooled output before the
            heads. Defaults to ``0.3``.
    """

    def __init__(self, num_labels_per_output, pretrained_model_name=None,
                 dropout=0.3):
        super().__init__()
        # Resolve the default at call time so the module constant remains the
        # single source of truth for the checkpoint name.
        if pretrained_model_name is None:
            pretrained_model_name = PRETRAINED_MODEL_NAME

        # Attribute names (bert / dropout / classifiers) are part of the
        # state-dict key layout and must not be renamed.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout)

        hidden_size = self.bert.config.hidden_size
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_size, n_labels)
            for n_labels in num_labels_per_output
        ])

    def forward(self, input_ids, attention_mask):
        """Return a list of logits tensors, one per classification head.

        Args:
            input_ids: Token-id tensor forwarded to the BERT encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            A list with one tensor per head, in the order given by
            ``num_labels_per_output``; the last dimension of each tensor is
            that head's class count.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return [classifier(pooled_output) for classifier in self.classifiers]
def train_and_save_model(csv_path, output_path='app/bert_model.pkl'):
    """Fine-tune the multi-head BERT classifier on a CSV and pickle the result.

    The CSV must contain ``TEXT_COLUMN`` and every column in
    ``LABEL_COLUMNS``. Each label column is integer-encoded with its own
    ``LabelEncoder``; 80% of the rows (fixed seed) are used for training and
    the remaining 20% are not used by this function.

    The pickled bundle is a dict with keys ``'model_state_dict'``,
    ``'tokenizer'`` and ``'label_encoders'``.

    Args:
        csv_path: Path (or anything ``pandas.read_csv`` accepts) of the
            training data.
        output_path: Destination of the pickled bundle. Missing parent
            directories are created.

    Returns:
        None. The function's result is the file written to ``output_path``.
    """
    from pathlib import Path

    df = pd.read_csv(csv_path)

    # Create the destination directory up front so a bad path fails before
    # the expensive training run instead of after it.
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # The tokenizer only accepts strings; empty cells are read as NaN floats,
    # so normalise them to empty strings first.
    X = df[TEXT_COLUMN].fillna('').astype(str).tolist()
    y = df[LABEL_COLUMNS]

    label_encoders = {}
    encoded_columns = {}
    for col in LABEL_COLUMNS:
        le = LabelEncoder()
        encoded_columns[col] = le.fit_transform(y[col])
        label_encoders[col] = le
    y_encoded = pd.DataFrame(encoded_columns)

    X_train, _, y_train, _ = train_test_split(
        X, y_encoded, test_size=0.2, random_state=42
    )

    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
    train_encodings = tokenizer(
        X_train,
        padding=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    # CrossEntropyLoss expects int64 class indices; the encoder's NumPy dtype
    # is platform dependent, so pin it explicitly.
    labels = [
        torch.tensor(y_train[col].values, dtype=torch.long)
        for col in LABEL_COLUMNS
    ]

    # Head order follows LABEL_COLUMNS explicitly, matching `labels` above.
    num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
    model = BertMultiOutput(num_labels_list).to(DEVICE)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    n_train = len(X_train)
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, n_train, BATCH_SIZE)):
            batch = slice(i, i + BATCH_SIZE)
            input_ids = train_encodings['input_ids'][batch].to(DEVICE)
            attention_mask = train_encodings['attention_mask'][batch].to(DEVICE)
            batch_labels = [label[batch].to(DEVICE) for label in labels]

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            # Total loss is the unweighted sum of the per-head losses.
            loss = sum(loss_fn(o, l) for o, l in zip(outputs, batch_labels))
            loss.backward()
            optimizer.step()

    # Move the weights to CPU before pickling. Tensors pickled from a CUDA
    # device cannot be unpickled with plain pickle.load on a CPU-only host.
    model_bundle = {
        'model_state_dict': model.cpu().state_dict(),
        'tokenizer': tokenizer,
        'label_encoders': label_encoders
    }
    with open(output_file, 'wb') as f:
        pickle.dump(model_bundle, f)
def load_model(path='app/bert_model.pkl'):
    """Restore the classifier, tokenizer and label encoders from a bundle.

    The file at ``path`` is expected to be a pickled dict with the keys
    ``'tokenizer'``, ``'label_encoders'`` and ``'model_state_dict'``.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file;
    only load bundles from a trusted source.

    Args:
        path: Location of the pickled bundle.

    Returns:
        A ``(model, tokenizer, label_encoders)`` tuple; the model is on
        ``DEVICE`` and switched to evaluation mode.
    """
    with open(path, 'rb') as bundle_file:
        saved = pickle.load(bundle_file)

    restored_tokenizer = saved['tokenizer']
    encoders = saved['label_encoders']

    # One head per stored encoder, sized by its number of known classes.
    head_sizes = []
    for encoder in encoders.values():
        head_sizes.append(len(encoder.classes_))

    net = BertMultiOutput(head_sizes)
    net = net.to(DEVICE)
    net.load_state_dict(saved['model_state_dict'])
    net.eval()

    return net, restored_tokenizer, encoders