# Uploaded by fzn0x via huggingface_hub: "Upload folder using huggingface_hub"
# (commit 2086153, verified).
import logging
import re
from collections import Counter
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput
# Shared resources created once at import time: the tokenizer and config for
# the "bert-base-uncased" checkpoint (config sized for two output labels),
# plus the torch device training runs on (CUDA when available, else CPU).
_BASE_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_BASE_CHECKPOINT)
config = BertConfig.from_pretrained(_BASE_CHECKPOINT, num_labels=2)
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
class WeightedBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose training loss is class-weighted.

    Identical to ``BertForSequenceClassification`` except that, when ``labels``
    are supplied, the loss is ``CrossEntropyLoss(weight=class_weights)``.

    Args:
        config: BERT model configuration; ``config.num_labels`` is the number
            of classes the loss is computed over.
        class_weights: per-class loss weights (tensor, array or list), or
            ``None`` for an unweighted loss. Converted to a float tensor.
    """

    def __init__(self, config, class_weights):
        super().__init__(config)
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float)
        # Register as a buffer (not a plain attribute) so the weights follow the
        # model through .to()/.cuda() and are replicated per device by
        # DataParallel; persistent=False keeps them out of state_dict() and
        # therefore out of checkpoints / save_pretrained().
        self.register_buffer("class_weights", class_weights, persistent=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run BERT and, when ``labels`` is given, the class-weighted loss.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` (plus ``loss`` when
            ``labels`` is given, and any hidden states / attentions the encoder
            returned). Unlike a plain dict, absent fields are omitted rather
            than stored as ``None``, so Trainer's ``predict()`` on unlabelled
            data only sees the logits.
        """
        # The parent must return an output object (we read .logits below),
        # and must not compute its own unweighted loss, hence labels=None.
        kwargs["return_dict"] = True
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, labels=None, **kwargs
        )
        logits = outputs.logits
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class SMSClassificationDataset(Dataset):
    """Map-style dataset over pre-tokenized SMS texts and their class labels.

    ``encodings`` maps feature names to per-example indexable sequences (for
    example the tensors in a tokenizer's output); ``labels`` holds one integer
    class id per example and is stored as a ``torch.long`` tensor. Each item
    is a dict of the indexed features followed by a ``"labels"`` entry.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        # One example per label.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature, values in self.encodings.items():
            sample[feature] = values[idx]
        sample["labels"] = self.labels[idx]
        return sample
def compute_metrics(eval_pred):
    """Compute accuracy and support-weighted precision/recall/F1.

    Intended as the ``compute_metrics`` hook of a Hugging Face ``Trainer``.
    Also prints the confusion matrix as a side effect.

    Args:
        eval_pred: ``(logits, labels)`` pair. ``logits`` is a 2-D array of
            shape ``(n_samples, n_classes)``, or a tuple whose first element is
            that array (what Trainer passes when the model returns extra
            tensors); ``labels`` holds the integer class ids.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    # Extra model outputs (e.g. hidden states) arrive as a tuple; the
    # classification logits always come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    predictions = logits.argmax(axis=-1)

    acc = accuracy_score(labels, predictions)
    # zero_division=0 on every metric: a class that is never predicted or
    # never present yields 0 instead of a warning, consistently across all three.
    precision = precision_score(labels, predictions, average="weighted", zero_division=0)
    recall = recall_score(labels, predictions, average="weighted", zero_division=0)
    f1 = f1_score(labels, predictions, average="weighted", zero_division=0)

    # Pin the class axis so the matrix is always n_classes x n_classes, even
    # when an evaluation batch contains / predicts only one class.
    cm = confusion_matrix(labels, predictions, labels=np.arange(logits.shape[-1]))
    print("Confusion Matrix:\n", cm)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
def _silence_transformers_logging():
    """Set every already-registered logger whose name contains 'transformers' to ERROR."""
    for name in list(logging.root.manager.loggerDict):
        if "transformers" in name.lower():
            logging.getLogger(name).setLevel(logging.ERROR)


def _load_splits(data_path):
    """Read the labelled SMS CSV and return a stratified 75/25 train/val split.

    The CSV must have ``label`` ('spam'/'ham') and ``text`` columns. Rows whose
    label is neither value, or whose text is missing, are dropped with a
    warning rather than turning into NaN labels downstream.

    Returns:
        ``(train_texts, val_texts, train_labels, val_labels)`` as lists.
    """
    df = pd.read_csv(data_path, encoding='iso-8859-1')[['label', 'text']].copy()
    label_mapping = {'spam': 1, 'ham': 0}
    df['label'] = df['label'].map(label_mapping)

    n_rows = len(df)
    df = df.dropna(subset=['label', 'text'])
    if len(df) < n_rows:
        logging.getLogger(__name__).warning(
            "Dropped %d rows with an unknown label or missing text", n_rows - len(df)
        )

    texts = df['text'].astype(str).tolist()
    labels = df['label'].astype(int).tolist()
    # Stratify so both splits keep the same spam/ham ratio; the class weights
    # below are derived from the training split.
    return train_test_split(
        texts, labels, test_size=0.25, random_state=42, stratify=labels
    )


def _build_model(class_weights):
    """Create the weighted classifier, copy in pretrained BERT weights, move it to ``device``."""
    model = WeightedBertForSequenceClassification(config, class_weights=class_weights)
    donor = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        use_safetensors=True,
        return_dict=False,
        attn_implementation="sdpa",
    )
    # strict=False tolerates key mismatches; report them instead of silently
    # training from partially uninitialised weights.
    incompatible = model.load_state_dict(donor.state_dict(), strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.getLogger(__name__).warning(
            "Pretrained weights did not fully match: missing=%s unexpected=%s",
            incompatible.missing_keys,
            incompatible.unexpected_keys,
        )
    del donor  # weights were copied; free the second model before training
    model.to(device)
    return model


def _save_artifacts(trainer, model, output_dir, log_dir):
    """Save tokenizer + model to ``output_dir`` and the log history CSV to ``log_dir``."""
    # Persist the model first so a problem writing logs can never lose it.
    tokenizer.save_pretrained(output_dir)
    model.save_pretrained(output_dir)

    # Nothing else is guaranteed to create this directory (report_to="none").
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    df_logs = pd.DataFrame(trainer.state.log_history)
    df_logs.to_csv(log_path / f"training_logs_{timestamp}.csv", index=False)


def train(data_path='data/spam.csv', output_dir='./models/pretrained', log_dir='./logs'):
    """Fine-tune the class-weighted BERT spam/ham classifier and save the results.

    Args:
        data_path: CSV with ``label`` and ``text`` columns.
        output_dir: Trainer checkpoint directory; the final model and tokenizer
            are also saved here.
        log_dir: Trainer logging directory; the ``log_history`` is written here
            as ``training_logs_<timestamp>.csv``.
    """
    train_texts, val_texts, train_labels, val_labels = _load_splits(data_path)

    # Per-class loss weights from the training split (sklearn 'balanced').
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_labels),
        y=train_labels,
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    _silence_transformers_logging()
    model = _build_model(class_weights)

    train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors="pt")
    val_encodings = tokenizer(val_texts, truncation=True, padding=True, return_tensors="pt")
    train_dataset = SMSClassificationDataset(train_encodings, train_labels)
    val_dataset = SMSClassificationDataset(val_encodings, val_labels)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=log_dir,
        logging_steps=10,
        eval_strategy="epoch",
        report_to="none",
        save_total_limit=1,
        load_best_model_at_end=True,
        save_strategy="epoch",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.train()

    _save_artifacts(trainer, model, output_dir, log_dir)
def _main():
    """Script entry point: run the full fine-tuning pipeline with default paths."""
    train()


if __name__ == "__main__":
    _main()