import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from typing import Dict, List, Tuple
class ToxicCommentDataset(Dataset):
    """Dataset that tokenizes toxic-comment texts on the fly for BERT."""

    def __init__(self, texts: List[str], labels: np.ndarray, tokenizer: BertTokenizer, max_length: int = 128):
        # Convert texts to a plain list if a pandas Series is passed
        self.texts = texts.tolist() if isinstance(texts, pd.Series) else texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        text = str(self.texts[idx])
        # Handle unusual line terminators: strip Unicode line/paragraph
        # separators, then collapse any remaining newlines into spaces
        text = text.replace('\u2028', ' ').replace('\u2029', ' ')
        text = ' '.join(text.splitlines())
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.float32)
        }
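
# Usage sketch (hedged): 'bert-base-uncased' is an assumed checkpoint name,
# not something this module mandates; swap in whatever tokenizer and label
# array your pipeline actually uses.
#
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   dataset = ToxicCommentDataset(['example comment'], np.zeros((1, 6)), tokenizer)
#   sample = dataset[0]  # dict with 'input_ids', 'attention_mask', 'labels'
#   assert sample['input_ids'].shape == (128,)  # padded to max_length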
def load_toxic_data(data_path: str) -> Tuple[List[str], np.ndarray]:
    """Load and prepare the toxic comment dataset."""
    try:
        # encoding='utf-8-sig' handles a UTF-8 BOM if one is present
        df = pd.read_csv(data_path, encoding='utf-8-sig', on_bad_lines='skip')
        # The six toxicity label columns
        toxic_categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        # Convert the text column to a list and the label columns to a numpy array
        texts = df['comment_text'].tolist()
        labels = df[toxic_categories].values
        return texts, labels
    except Exception as e:
        raise RuntimeError(f"Error loading data from {data_path}: {e}") from e
def create_data_loaders(
    texts: List[str],
    labels: np.ndarray,
    tokenizer: BertTokenizer,
    train_ratio: float = 0.8,
    batch_size: int = 32,
    num_workers: int = 4  # Adjusted for Windows
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation data loaders."""
    try:
        # Calculate the split index; note this is a sequential split,
        # so shuffle the data beforehand if row order is not random
        dataset_size = len(texts)
        train_size = int(dataset_size * train_ratio)
        # Split data
        train_texts = texts[:train_size]
        train_labels = labels[:train_size]
        val_texts = texts[train_size:]
        val_labels = labels[train_size:]
        # Create datasets
        train_dataset = ToxicCommentDataset(train_texts, train_labels, tokenizer)
        val_dataset = ToxicCommentDataset(val_texts, val_labels, tokenizer)
        # Create data loaders with Windows-friendly settings
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,  # Speeds up host-to-GPU transfers under CUDA
            persistent_workers=num_workers > 0  # Keep workers alive between epochs; only valid with num_workers > 0
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=num_workers > 0
        )
        return train_loader, val_loader
    except Exception as e:
        raise RuntimeError(f"Error creating data loaders: {e}") from e