|
|
import pandas as pd
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
from transformers import BertTokenizer
|
|
|
from typing import Dict, List, Tuple
|
|
|
import numpy as np
|
|
|
import os
|
|
|
|
|
|
class ToxicCommentDataset(Dataset):
    """PyTorch dataset for multi-label toxic-comment classification.

    Each item is a BERT-ready encoding of one comment plus its
    multi-hot toxicity label vector.
    """

    def __init__(self, texts: List[str], labels: np.ndarray, tokenizer: BertTokenizer, max_length: int = 128):
        """
        Args:
            texts: Comment strings (a pandas Series is also accepted).
            labels: Array-like of shape (n_samples, n_categories) with 0/1
                labels (ndarray, DataFrame, or nested list).
            tokenizer: Tokenizer used to encode each comment.
            max_length: Maximum token length; longer comments are truncated.
        """
        # Normalize texts to a plain list so indexing is uniform.
        self.texts = texts.tolist() if isinstance(texts, pd.Series) else texts
        # Normalize labels to a float32 ndarray up front: this accepts
        # DataFrames/lists too and lets __getitem__ build tensors cheaply.
        # (The old code passed raw rows to torch.FloatTensor, which fails
        # for non-ndarray label containers.)
        self.labels = np.asarray(labels, dtype=np.float32)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Encode sample ``idx`` and return model-ready tensors.

        Returns:
            Dict with 'input_ids' and 'attention_mask' tensors of shape
            (max_length,) and a float 'labels' tensor.
        """
        text = str(self.texts[idx])

        # Strip Unicode line/paragraph separators and collapse remaining
        # line breaks: embedded newlines carry no toxicity signal and can
        # break downstream logging / CSV round-trips.
        text = text.replace('\u2028', ' ').replace('\u2029', ' ')
        text = ' '.join(text.splitlines())

        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            # Tokenizer returns shape (1, max_length); flatten so the
            # DataLoader's default collate stacks to (batch, max_length).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # torch.tensor over the legacy FloatTensor constructor:
            # explicit dtype, and safe for both scalar and array labels.
            'labels': torch.tensor(label, dtype=torch.float)
        }
|
|
|
|
|
|
def load_toxic_data(data_path: str) -> Tuple[List[str], np.ndarray]:
    """Load and prepare the toxic comment dataset.

    Args:
        data_path: Path to a CSV with a 'comment_text' column and the six
            Jigsaw toxicity label columns.

    Returns:
        Tuple of (texts, labels): a list of comment strings and an
        (n_samples, 6) label array, row-aligned.

    Raises:
        RuntimeError: If the file cannot be read or required columns are
            missing; the original exception is chained as the cause.
    """
    toxic_categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    try:
        # utf-8-sig strips a leading BOM if present; skip malformed rows
        # rather than failing the whole load.
        df = pd.read_csv(data_path, encoding='utf-8-sig', on_bad_lines='skip')

        # Fail early with a precise message instead of an opaque KeyError
        # from the column selections below.
        missing = [c for c in ['comment_text'] + toxic_categories if c not in df.columns]
        if missing:
            raise KeyError(f"missing required columns: {missing}")

        # Empty comments come back from read_csv as NaN floats; map them to
        # empty strings so downstream code never sees float('nan').
        texts = df['comment_text'].fillna('').tolist()
        labels = df[toxic_categories].values

        return texts, labels
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error loading data from {data_path}: {str(e)}") from e
|
|
|
|
|
|
def create_data_loaders(
    texts: List[str],
    labels: np.ndarray,
    tokenizer: BertTokenizer,
    train_ratio: float = 0.8,
    batch_size: int = 32,
    num_workers: int = 4
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation data loaders.

    The split is sequential (the first ``train_ratio`` fraction of rows is
    used for training); shuffle the rows beforehand if ordering matters.

    Args:
        texts: Comment strings.
        labels: Label array row-aligned with ``texts``.
        tokenizer: Tokenizer forwarded to ToxicCommentDataset.
        train_ratio: Fraction of samples assigned to the training set.
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader (0 = load in main process).

    Returns:
        Tuple of (train_loader, val_loader).

    Raises:
        RuntimeError: If dataset or loader construction fails; the original
            exception is chained as the cause.
    """
    try:
        dataset_size = len(texts)
        train_size = int(dataset_size * train_ratio)

        train_dataset = ToxicCommentDataset(texts[:train_size], labels[:train_size], tokenizer)
        val_dataset = ToxicCommentDataset(texts[train_size:], labels[train_size:], tokenizer)

        # BUG FIX: persistent_workers=True with num_workers=0 raises
        # "ValueError: persistent_workers option needs num_workers > 0"
        # in PyTorch, so derive the flag from num_workers instead of
        # hard-coding True.
        persistent = num_workers > 0

        # pin_memory speeds host->GPU transfers; harmless on CPU-only runs.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,  # reshuffle training batches every epoch
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=persistent
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,  # keep validation order deterministic
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=persistent
        )

        return train_loader, val_loader
    except Exception as e:
        # Chain so the underlying error isn't lost.
        raise RuntimeError(f"Error creating data loaders: {str(e)}") from e