# BERT-BASED-NEWS-CLASSIFICATION / DL2_BERT_Model_Based_Classification.py
# Last commit: 60ef55c "Initial commit" by jrmd
import re
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
accuracy_score,
confusion_matrix,
precision_score,
recall_score,
)
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import AutoTokenizer, BertModel
import wandb
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Training hyper-parameters.
NUM_EPOCHS = 5
BATCH_SIZE = 16

# Artifacts written during training / dataset construction and read back at evaluation.
SAVED_MODEL_PATH = "custom_bert_model.torch"
SAVED_TARGET_CAT_PATH = "bbc-news-categories.torch"

# Input dataset file parsed by CustomBertDataset.
DS_PATH = "bbc-news-data.csv"
from typing import DefaultDict
class CustomBertDataset(Dataset):
    """Tab-separated news-classification dataset tokenized for a BERT model.

    Each data row (the first line of the file is skipped as a header) is split
    on runs of tabs. Field 3 is used as the text and field 0 as the category.
    Categories are mapped to integer targets in sorted order. Samples are then
    reordered round-robin across categories, in first-seen category order.

    Each item is ``(input_ids, attention_mask, target)``. The first two are
    padded/truncated to ``self.max_len`` tokens; ``target`` is a long tensor.
    """

    # Upper bound for ``max_len`` (512 is the maximum input length of BERT-base).
    _MAX_TOKENS = 512

    def __init__(
        self,
        file_path,
        model_path="google-bert/bert-base-uncased",
        saved_target_cats_path=SAVED_TARGET_CAT_PATH,
    ):
        """Load, label-encode, interleave and measure the dataset.

        Args:
            file_path: path of the tab-separated dataset file (read as UTF-8).
            model_path: Hugging Face tokenizer name or local path.
            saved_target_cats_path: where the sorted list of category names
                is written with ``torch.save``. Index ``i`` of the list names
                target ``i``. Names are stored as plain ``str``, so the file
                can be reloaded with ``torch.load(..., weights_only=True)``.

        Raises:
            ValueError: if the file contains no data rows.
            IndexError: if a data row has fewer than four tab-separated fields.
        """
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # (n, 2) array of [text, category] rows.
        self.lines = self._read_rows(file_path)
        self.corpus = np.array(self.lines[:, 0])
        self.elem_cats = self.lines[:, 1]
        # Convert numpy string scalars to plain str. numpy scalars pickle through
        # numpy internals, which the weights-only torch unpickler refuses.
        self.unique_cats = sorted({str(cat) for cat in self.elem_cats})
        self.num_class = len(self.unique_cats)
        self.cats_dict = {cat: i for i, cat in enumerate(self.unique_cats)}
        self.targets = np.array([self.cats_dict[str(cat)] for cat in self.elem_cats])
        torch.save(self.unique_cats, saved_target_cats_path)

        self.final_corpus, self.final_targets = self._interleave_by_class(
            self.corpus, self.targets
        )
        self.corpus = np.array(self.final_corpus)
        self.targets = np.array(self.final_targets)

        self.max_len = self._longest_encoding()
        print(f"Max length : {self.max_len}")

    @staticmethod
    def _read_rows(file_path):
        """Parse ``file_path`` into an (n, 2) numpy array of [text, category]."""
        splitter = re.compile(r"\t+")
        rows = []
        with open(file_path, encoding="utf-8") as handle:
            for line_no, line in enumerate(handle):
                # Skip the header row and empty lines.
                if line_no == 0 or line == "\n":
                    continue
                fields = splitter.split(line.replace("\n", ""))
                rows.append([fields[3], fields[0]])
        if not rows:
            raise ValueError(f"No data rows found in {file_path!r}")
        return np.array(rows)

    @staticmethod
    def _interleave_by_class(texts, targets):
        """Reorder samples round-robin over classes, in first-seen class order.

        Each round takes the next remaining sample of every class. Classes
        that have run out are skipped. Returns ``(texts, targets)`` as lists.
        """
        groups = {}
        for text, target in zip(texts, targets):
            groups.setdefault(target, []).append(text)

        out_texts, out_targets = [], []
        rounds = max((len(members) for members in groups.values()), default=0)
        for idx in range(rounds):
            for target, members in groups.items():
                if idx < len(members):
                    out_texts.append(members[idx])
                    out_targets.append(target)
        return out_texts, out_targets

    def _longest_encoding(self):
        """Return the longest encoded length in the corpus, capped at _MAX_TOKENS."""
        longest = 0
        for text in self.corpus:
            longest = max(longest, len(self.tokenizer.encode(text, add_special_tokens=True)))
            if longest >= self._MAX_TOKENS:
                # The cap is reached; later texts cannot change the result.
                break
        return min(longest, self._MAX_TOKENS)

    def __len__(self):
        return len(self.corpus)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, target)`` for sample ``idx``."""
        encoded = self.tokenizer(
            self.corpus[idx],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of 1; drop it so
        # the DataLoader can stack samples itself.
        return (
            encoded["input_ids"].squeeze(0),
            encoded["attention_mask"].squeeze(0),
            torch.tensor(self.targets[idx], dtype=torch.long),
        )
class CustomBertModel(nn.Module):
    """Frozen pretrained BERT encoder followed by a trainable linear head.

    Args:
        num_class: number of output logits produced per input sequence.
        model_path: Hugging Face model name or local path passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, num_class, model_path="google-bert/bert-base-uncased"):
        super().__init__()
        self.model_path = model_path
        self.num_class = num_class
        self.bert = BertModel.from_pretrained(model_path)
        # Freeze every encoder weight: only the projection head is trained.
        self.bert.requires_grad_(False)
        self.proj_lin = nn.Linear(self.bert.config.hidden_size, num_class)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch, num_class).

        The logits are computed from the final hidden state at position 0,
        which is the [CLS] token for BERT tokenizers that add special tokens.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        return self.proj_lin(first_token_state)
def train_step(
    model,
    train_dataloader,
    loss_fn,
    optimizer,
    *,
    num_epochs=NUM_EPOCHS,
    saved_model_path=SAVED_MODEL_PATH,
    run=None,
):
    """Train ``model`` for ``num_epochs`` epochs, then save its ``state_dict``.

    Args:
        model: module called as ``model(input_ids, attention_mask)``. Batches
            are moved to the module-level ``device``, so the model must be
            on that device as well.
        train_dataloader: yields ``(input_ids, attention_mask, target)`` batches.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        optimizer: optimizer over the trainable parameters of ``model``.
        num_epochs: number of passes over ``train_dataloader``.
        saved_model_path: file the trained ``state_dict`` is written to.
        run: object with a ``log(dict)`` method (e.g. a wandb run). Defaults to
            the active ``wandb.run``. Per-step logging is skipped when there
            is no active run.
    """
    # Resolve the logger explicitly instead of relying on a ``run`` global
    # that only exists when this file is executed as a script.
    logger_run = run if run is not None else wandb.run
    num_iterations = len(train_dataloader)

    for epoch in range(1, num_epochs + 1):
        print(f"Training Epoch n° {epoch}")
        model.train()
        for step, (input_ids, attention_mask, target) in enumerate(train_dataloader, start=1):
            output = model(input_ids.to(device), attention_mask.to(device))
            loss = loss_fn(output, target.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log/print a Python float rather than the raw (device, grad) tensor.
            loss_value = loss.item()
            if logger_run is not None:
                logger_run.log({"Training loss": loss_value})
            print(f"Epoch {epoch} | step {step} / {num_iterations} | loss : {loss_value}")

    torch.save(model.state_dict(), saved_model_path)
    print(f"Model saved at {saved_model_path}")
def eval_step(
    test_dataloader,
    loss_fn,
    num_class,
    saved_model_path=SAVED_MODEL_PATH,
    saved_target_cats_path=SAVED_TARGET_CAT_PATH,
    *,
    run=None,
):
    """Reload the saved classifier, evaluate it and plot a confusion matrix.

    Args:
        test_dataloader: yields ``(input_ids, attention_mask, target)`` batches.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        num_class: number of output classes the checkpoint was trained with.
        saved_model_path: ``state_dict`` file to load into a fresh
            ``CustomBertModel``.
        saved_target_cats_path: ``torch.save``-d list of category names.
            Index ``i`` names target ``i``.
        run: object with a ``log(dict)`` method (e.g. a wandb run). Defaults to
            the active ``wandb.run``. Per-batch logging is skipped when there
            is no active run.

    Prints the per-batch loss, accuracy and macro precision/recall, then shows
    a confusion-matrix heatmap with ``plt.show()``.
    """
    logger_run = run if run is not None else wandb.run

    # Build the model on the target device first, then load the checkpoint
    # directly onto that device. map_location also lets a GPU-trained
    # checkpoint load on a CPU-only machine.
    saved_model = CustomBertModel(num_class).to(device)
    # A state_dict contains only tensors, so the restricted weights-only
    # unpickler is sufficient and refuses to execute arbitrary pickled code.
    state_dict = torch.load(saved_model_path, map_location=device, weights_only=True)
    saved_model.load_state_dict(state_dict)
    saved_model.eval()
    print(f"Model loaded from path :{saved_model_path}")

    y_pred = []
    y_true = []
    with torch.no_grad():
        for input_ids, attention_mask, target in test_dataloader:
            output = saved_model(input_ids.to(device), attention_mask.to(device))
            loss_value = loss_fn(output, target.to(device)).item()
            if logger_run is not None:
                logger_run.log({"Eval loss": loss_value})
            print(f"Eval loss : {loss_value}")
            y_pred.extend(output.argmax(dim=1).cpu().tolist())
            y_true.extend(target.tolist())

    if not y_true:
        print("Empty test set: no metrics to report.")
        return

    # weights_only=False is kept for this file because CustomBertDataset may
    # have saved the names as numpy string scalars (they come from a numpy
    # array), which the weights-only unpickler rejects. This file is written by
    # this script; only load one from a trusted source. Names are normalized
    # to plain str.
    class_labels = [
        str(label) for label in torch.load(saved_target_cats_path, weights_only=False)
    ]
    true_labels = [class_labels[i] for i in y_true]
    pred_labels = [class_labels[i] for i in y_pred]

    print(f"Accuracy : {accuracy_score(true_labels, pred_labels)}")
    precision = precision_score(true_labels, pred_labels, average="macro", zero_division=0)
    recall = recall_score(true_labels, pred_labels, average="macro", zero_division=0)
    print(f"Precision (macro) : {precision}")
    print(f"Recall (macro) : {recall}")

    cm = confusion_matrix(true_labels, pred_labels, labels=class_labels)
    df_cm = pd.DataFrame(cm, index=class_labels, columns=class_labels)
    sns.heatmap(df_cm, annot=True, fmt="d")
    plt.title("Confusion Matrix for BBC News Dataset")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()
if __name__ == "__main__":
    wandb.login()
    # Module-level wandb run; it is also the active ``wandb.run``.
    run = wandb.init(project="DIT-Bert-bbc-news-project")

    our_bert_dataset = CustomBertDataset(DS_PATH)
    dataset_size = len(our_bert_dataset)
    print(f"Size of bert dataset : {dataset_size}")

    # Contiguous 80/20 split over the dataset's class-interleaved order.
    # NOTE(review): when class sizes differ, the tail of the round-robin order
    # contains only the larger classes, so the test split is not stratified.
    split_idx = int(dataset_size * 0.8)
    train_dataset = Subset(our_bert_dataset, range(split_idx))
    test_dataset = Subset(our_bert_dataset, range(split_idx, dataset_size))
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    our_bert_model = CustomBertModel(our_bert_dataset.num_class).to(device)
    loss_fn = nn.CrossEntropyLoss()
    # BERT is frozen, so only the classification head's parameters are optimized.
    optimizer = optim.SGD(
        (p for p in our_bert_model.parameters() if p.requires_grad), lr=0.01
    )

    train_step(our_bert_model, train_dataloader, loss_fn, optimizer)
    eval_step(test_dataloader, loss_fn, our_bert_dataset.num_class)

    # Mark the wandb run as finished explicitly instead of relying on process
    # exit (matters when this code runs inside a notebook session).
    run.finish()