"""Train a linear classifier on top of a frozen BERT encoder.

The script reads a tab-separated news file, trains a single linear layer on
the encoder's first-token representation, logs losses to Weights & Biases and
finally plots a confusion matrix for the held-out part of the data.
"""

import re
from collections import defaultdict, deque
from typing import DefaultDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import AutoTokenizer, BertModel

import wandb

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 5
BATCH_SIZE = 16
SAVED_MODEL_PATH = "custom_bert_model.torch"
SAVED_TARGET_CAT_PATH = "bbc-news-categories.torch"
DS_PATH = "bbc-news-data.csv"

# Upper bound applied to the padded sequence length.
# NOTE(review): presumably the position limit of the BERT checkpoint - confirm
# before switching to a different `model_path`.
MAX_SEQ_LEN = 512

# Fields are separated by one or more tab characters.
_FIELD_SEP = re.compile(r"\t+")


def _log_metrics(metrics):
    """Forward *metrics* to the active wandb run; do nothing when there is none.

    Lets `train_step` / `eval_step` be called from code that never ran
    `wandb.init()` instead of failing on a missing global.
    """
    if wandb.run is not None:
        wandb.log(metrics)


class CustomBertDataset(Dataset):
    """Text-classification dataset backed by a tab-separated file.

    The first line of the file is skipped as a header and lines consisting of
    only a newline are ignored. For every remaining line, field 3 is used as
    the text and field 0 as the category name. Category names are sorted and
    mapped to integer targets; the sorted list is written to
    ``saved_target_cats_path`` with ``torch.save``.

    Samples are re-ordered round-robin over the classes (one sample of each
    class in turn, classes in order of first appearance) until every class is
    exhausted.

    Args:
        file_path: Path of the tab-separated input file.
        model_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
        saved_target_cats_path: Where the sorted category list is saved.

    Raises:
        IndexError: If a data line has fewer than four fields, or if the file
            contains no data lines at all.
    """

    def __init__(
        self,
        file_path,
        model_path="google-bert/bert-base-uncased",
        saved_target_cats_path=SAVED_TARGET_CAT_PATH,
    ):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Context manager so the handle is closed even if parsing fails.
        with open(file_path) as handle:
            raw_lines = handle.readlines()

        rows = []
        for line_no, line in enumerate(raw_lines):
            if line_no == 0 or line == "\n":
                continue  # header line / blank line
            fields = _FIELD_SEP.split(line.replace("\n", ""))
            rows.append([fields[3], fields[0]])

        # Object dtype keeps plain Python strings. A fixed-width unicode array
        # would pad every cell to the length of the longest text.
        self.lines = np.array(rows, dtype=object)
        self.corpus = self.lines[:, 0]
        self.elem_cats = self.lines[:, 1]

        self.unique_cats = sorted(set(self.elem_cats))
        self.num_class = len(self.unique_cats)
        self.cats_dict = {cat: i for i, cat in enumerate(self.unique_cats)}
        self.targets = np.array([self.cats_dict[cat] for cat in self.elem_cats])
        torch.save(self.unique_cats, saved_target_cats_path)

        # Group texts per class; deques give O(1) pops from the front.
        buckets: DefaultDict[int, deque] = defaultdict(deque)
        for text, target in zip(self.corpus, self.targets):
            buckets[target].append(text)

        # Round-robin interleave: one sample per non-empty class per pass.
        self.final_corpus = []
        self.final_targets = []
        total = len(self.corpus)
        while len(self.final_corpus) < total:
            for key, bucket in buckets.items():
                if bucket:
                    self.final_corpus.append(bucket.popleft())
                    self.final_targets.append(key)

        self.corpus = np.array(self.final_corpus, dtype=object)
        self.targets = np.array(self.final_targets)

        # Pad to the longest tokenized text, capped at MAX_SEQ_LEN.
        longest = max(
            (
                len(self.tokenizer.encode(text, add_special_tokens=True))
                for text in self.corpus
            ),
            default=0,
        )
        self.max_len = min(longest, MAX_SEQ_LEN)
        print(f"Max length : {self.max_len}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.corpus)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, target)`` for sample *idx*.

        The text is padded / truncated to ``self.max_len`` tokens; the leading
        batch dimension added by ``return_tensors="pt"`` is squeezed away.
        """
        text = self.corpus[idx]
        target = self.targets[idx]
        encoded_input = self.tokenizer(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return (
            encoded_input["input_ids"].squeeze(0),
            encoded_input["attention_mask"].squeeze(0),
            torch.tensor(target, dtype=torch.long),
        )


class CustomBertModel(nn.Module):
    """Frozen BERT encoder followed by a trainable linear projection.

    Args:
        num_class: Number of output classes of the linear layer.
        model_path: Name or path handed to ``BertModel.from_pretrained``.
    """

    def __init__(self, num_class, model_path="google-bert/bert-base-uncased"):
        super().__init__()
        self.model_path = model_path
        self.num_class = num_class
        self.bert = BertModel.from_pretrained(self.model_path)

        # Freeze the encoder: only the projection layer is trained.
        # NOTE(review): `model.train()` still switches the encoder's dropout
        # layers to training mode - confirm that this is intended.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.proj_lin = nn.Linear(self.bert.config.hidden_size, self.num_class)

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the first token's hidden state."""
        x = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:, 0, :]
        x = self.proj_lin(x)
        return x


def train_step(model, train_dataloader, loss_fn, optimizer):
    """Train *model* for ``NUM_EPOCHS`` epochs and save its state dict.

    Each batch must unpack into ``(input_ids, attention_mask, target)``. The
    loss of every step is printed and, when a wandb run is active, logged
    under ``"Training loss"``. After the last epoch the state dict is written
    to ``SAVED_MODEL_PATH``.
    """
    num_iterations = len(train_dataloader)
    for epoch in range(NUM_EPOCHS):
        print(f"Training Epoch n° {epoch + 1}")
        model.train()
        for step, (input_ids, attention_mask, target) in enumerate(
            train_dataloader, start=1
        ):
            output = model(input_ids.to(device), attention_mask.to(device))
            loss = loss_fn(output, target.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Plain float: avoids logging/printing a tensor with its grad_fn.
            loss_value = loss.item()
            _log_metrics({"Training loss": loss_value})
            print(
                f"Epoch {epoch + 1} | step {step} / {num_iterations} "
                f"| loss : {loss_value}"
            )

    # Save model
    torch.save(model.state_dict(), SAVED_MODEL_PATH)
    print(f"Model saved at {SAVED_MODEL_PATH}")


def eval_step(
    test_dataloader,
    loss_fn,
    num_class,
    saved_model_path=SAVED_MODEL_PATH,
    saved_target_cats_path=SAVED_TARGET_CAT_PATH,
):
    """Evaluate the saved model and plot a confusion matrix.

    A fresh ``CustomBertModel`` is built and its weights are loaded from
    ``saved_model_path``. Per-batch losses are printed and, when a wandb run
    is active, logged under ``"Eval loss"``. Accuracy is printed and the
    confusion matrix is shown as a heatmap.

    Args:
        test_dataloader: Yields ``(input_ids, attention_mask, target)`` batches.
        loss_fn: Callable applied to ``(logits, target)``.
        num_class: Number of classes the saved model was trained with.
        saved_model_path: File holding the model state dict.
        saved_target_cats_path: File holding the list of category names.
    """
    y_pred = []
    y_true = []

    # `map_location` lets a checkpoint saved on GPU load on a CPU-only host;
    # a state dict only holds tensors, so the restricted unpickler suffices.
    state_dict = torch.load(
        saved_model_path, map_location=device, weights_only=True
    )
    saved_model = CustomBertModel(num_class)
    saved_model.load_state_dict(state_dict)
    saved_model = saved_model.to(device)
    saved_model.eval()  # Set the model to evaluation mode
    print(f"Model loaded from path :{saved_model_path}")

    with torch.no_grad():
        for input_ids, attention_mask, target in test_dataloader:
            output = saved_model(input_ids.to(device), attention_mask.to(device))
            loss_value = loss_fn(output, target.to(device)).item()
            _log_metrics({"Eval loss": loss_value})
            print(f"Eval loss : {loss_value}")
            y_pred.extend(output.argmax(dim=1).cpu().tolist())
            y_true.extend(target.tolist())

    # NOTE(security): weights_only=False unpickles arbitrary objects. It is
    # kept because category files written earlier may contain NumPy string
    # scalars; only load files from a trusted source.
    class_labels = torch.load(saved_target_cats_path, weights_only=False)
    true_labels = [class_labels[i] for i in y_true]
    pred_labels = [class_labels[i] for i in y_pred]

    print(f"Accuracy : {accuracy_score(true_labels, pred_labels)}")

    cm = confusion_matrix(true_labels, pred_labels, labels=class_labels)
    df_cm = pd.DataFrame(cm, index=class_labels, columns=class_labels)
    sns.heatmap(df_cm, annot=True, fmt="d")
    plt.title("Confusion Matrix for BBC News Dataset")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()


def main():
    """Build the dataset, train the classifier and evaluate it."""
    wandb.login()
    run = wandb.init(project="DIT-Bert-bbc-news-project")
    try:
        our_bert_dataset = CustomBertDataset(DS_PATH)
        dataset_size = len(our_bert_dataset)
        print(f"Size of bert dataset : {dataset_size}")

        # Sequential 80/20 split over the class-interleaved sample order.
        # NOTE(review): the tail of that order only contains the larger
        # classes, so the test part is not class-balanced when class sizes
        # differ - confirm that this is acceptable.
        split_at = int(dataset_size * 0.8)
        train_dataset = Subset(our_bert_dataset, range(split_at))
        test_dataset = Subset(our_bert_dataset, range(split_at, dataset_size))

        train_dataloader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE, shuffle=True
        )
        test_dataloader = DataLoader(
            test_dataset, batch_size=BATCH_SIZE, shuffle=False
        )

        our_bert_model = CustomBertModel(our_bert_dataset.num_class)
        our_bert_model = our_bert_model.to(device)

        loss_fn = nn.CrossEntropyLoss()
        # Only the parameters left trainable are handed to the optimizer.
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, our_bert_model.parameters()),
            lr=0.01,
        )

        train_step(our_bert_model, train_dataloader, loss_fn, optimizer)
        eval_step(test_dataloader, loss_fn, our_bert_dataset.num_class)
    finally:
        # Close the wandb run even when training or evaluation fails.
        run.finish()


if __name__ == "__main__":
    main()