import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, AutoTokenizer
import pandas as pd
import numpy as np
import logging
logger = logging.getLogger(__name__)
class DocumentDataset(Dataset):
    """
    Dataset for document classification.

    Tokenizes raw texts on the fly with a HuggingFace tokenizer and
    validates labels, auto-remapping them to be zero-indexed when a
    dataset ships e.g. 1-based class ids.  Supports either one label per
    example (1-D labels) or multiple label columns (2-D labels).
    """

    def __init__(self, texts, labels, tokenizer_name='bert-base-uncased', max_length=512, num_classes=None):
        """
        Args:
            texts: sequence of raw document strings.
            labels: per-example labels; None (inference), a 1-D sequence
                (single label column), or a 2-D array / list of lists
                (multiple label columns).
            tokenizer_name: HuggingFace tokenizer identifier.
            max_length: maximum token length; longer texts are truncated.
            num_classes: optional expected class count, used only to warn
                about out-of-range labels.
        """
        self.texts = texts
        self.num_classes = num_classes
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
        if labels is None:
            # No labels supplied (e.g. inference): dummy zeros so
            # __getitem__ can still return a 'label' field.
            self.labels = np.zeros(len(texts), dtype=int)
        else:
            self.labels = self._normalize_labels(labels, num_classes)

    @staticmethod
    def _validate_column(column, num_classes):
        """
        Validate one 1-D label column and return it as a zero-indexed array.

        Warns when labels fall outside [0, num_classes) and remaps
        non-zero-based labels (e.g. {1, 2, 3} -> {0, 1, 2}).
        """
        unique_labels = set(column)
        min_label = min(unique_labels) if unique_labels else 0
        max_label = max(unique_labels) if unique_labels else 0
        if num_classes is not None and (min_label < 0 or max_label >= num_classes):
            logger.warning(f"Label Range Error: Labels must be between 0 and {num_classes-1}, "
                           f"but found range [{min_label}, {max_label}]")
            logger.warning(f"Unique label values: {sorted(unique_labels)}")
        if min_label != 0:
            # Some datasets ship 1-based labels; remap to start at 0.
            logger.warning("Auto-correcting labels to be zero-indexed...")
            label_map = {original: idx for idx, original in enumerate(sorted(unique_labels))}
            column = np.array([label_map[label] for label in column])
            logger.warning(f"New unique label values: {sorted(set(column))}")
        return np.asarray(column)

    @classmethod
    def _normalize_labels(cls, labels, num_classes):
        """
        Validate labels of either supported shape.

        BUG FIX: the previous implementation dispatched on ``type(labels)``
        and treated every list/ndarray as a 2-D multi-column array, so a
        plain 1-D list or array of labels crashed on ``labels.shape[1]``.
        Dispatch on dimensionality instead.
        """
        labels = np.array(labels)
        if labels.ndim == 1:
            # Single label column.
            return cls._validate_column(labels, num_classes)
        # 2-D: one label column per task; validate each independently.
        for i in range(labels.shape[1]):
            labels[:, i] = cls._validate_column(labels[:, i], num_classes)
        return labels

    def __len__(self):
        """Number of documents in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize one document and return model-ready tensors."""
        text = str(self.texts[idx])
        label = self.labels[idx] if self.labels is not None else torch.tensor(0, dtype=torch.long)
        # Tokenize with attention mask, padding to max_length, truncating long docs.
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

    def get_text_(self, idx):
        """Get the original text (and label, if any) for a given index."""
        return {
            'text': self.texts[idx],
            'label': self.labels[idx] if self.labels is not None else None
        }
def load_data(data_path, text_col='text', label_col: str | list = 'label', validation_split=0.1, test_split=0.1, seed=42):
    """
    Load data from a CSV/TSV file and split it into train/validation/test sets.

    Non-numeric label columns are mapped to contiguous integer ids; numeric
    columns that do not start at 0 are remapped to be zero-indexed.

    Args:
        data_path: path to a ``.csv`` or ``.tsv`` file.
        text_col: name of the text column.
        label_col: a single label column name, or a list of names for
            multi-task labels (yields a 2-D label array).
        validation_split: fraction of examples held out for validation.
        test_split: fraction of examples held out for testing.
        seed: RNG seed, so the shuffle/split is reproducible.

    Returns:
        ``((train_texts, train_labels), (val_texts, val_labels),
        (test_texts, test_labels))``

    Raises:
        ValueError: for unsupported file extensions, or (in the
            multi-column case) a label column missing from the file.
    """
    # Determine file format from the extension.
    if data_path.endswith('.csv'):
        df = pd.read_csv(data_path)
    elif data_path.endswith('.tsv'):
        df = pd.read_csv(data_path, sep='\t')
    else:
        raise ValueError("Unsupported file format. Please provide CSV or TSV file.")

    if isinstance(label_col, list):
        # Multiple label columns -> build an (n_samples, n_columns) array.
        encoded_columns = []
        for label in label_col:
            if label not in df.columns:
                raise ValueError(f"Label column '{label}' not found in the dataset.")
            if not np.issubdtype(df[label].dtype, np.number):
                # Map string/categorical labels to contiguous integer ids.
                label_map = {lbl: idx for idx, lbl in enumerate(sorted(df[label].unique()))}
                column = df[label].map(label_map).values
                # Log the mapping for reference.
                logger.info(f"Label mapping for column '{label}': {label_map}")
            else:
                # Column is already numeric; only fix non-zero-indexed labels.
                column = df[label].values
                min_label = column.min()
                if min_label != 0:
                    logger.warning(f"Labels in numeric column '{label}' don't start from 0 (min={min_label}). Converting to zero-indexed...")
                    label_map = {lbl: idx for idx, lbl in enumerate(sorted(set(column)))}
                    column = np.array([label_map[lbl] for lbl in column])
            encoded_columns.append(column.reshape(-1, 1))
        labels = np.concatenate(encoded_columns, axis=1)
    else:
        # Single label column.
        if not np.issubdtype(df[label_col].dtype, np.number):
            label_map = {label: idx for idx, label in enumerate(sorted(df[label_col].unique()))}
            labels = df[label_col].map(label_map).values
            logger.info(f"Label mapping: {label_map}")
        else:
            labels = df[label_col].values
            min_label = labels.min()
            if min_label != 0:
                logger.warning(f"Labels don't start from 0 (min={min_label}). Converting to zero-indexed...")
                label_map = {label: idx for idx, label in enumerate(sorted(set(labels)))}
                labels = np.array([label_map[label] for label in labels])

    texts = df[text_col].values

    # Shuffle and split.  A dedicated RandomState yields the identical
    # permutation to np.random.seed(seed) + np.random.permutation, but
    # without mutating the process-global RNG state.
    rng = np.random.RandomState(seed)
    indices = rng.permutation(len(texts))
    test_size = int(test_split * len(texts))
    val_size = int(validation_split * len(texts))
    train_size = len(texts) - test_size - val_size
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]
    train_texts, train_labels = texts[train_indices], labels[train_indices]
    val_texts, val_labels = texts[val_indices], labels[val_indices]
    test_texts, test_labels = texts[test_indices], labels[test_indices]

    # Log stats about the dataset.
    logger.info(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")
    return (train_texts, train_labels), (val_texts, val_labels), (test_texts, test_labels)
def create_data_loaders(train_data, val_data, test_data, tokenizer_name='bert-base-uncased',
                        max_length=512, batch_size=16, num_classes=None, return_datasets=False):
    """
    Create DataLoader objects for training, validation and testing.

    Each of ``train_data`` / ``val_data`` / ``test_data`` is a
    ``(texts, labels)`` pair.  An empty split yields ``None`` in place of
    its loader (with a warning).  When ``return_datasets`` is True the
    three ``DocumentDataset`` objects are returned instead of loaders.
    """
    # Build one DocumentDataset per split.
    datasets = tuple(
        DocumentDataset(texts, labels, tokenizer_name, max_length, num_classes)
        for texts, labels in (train_data, val_data, test_data)
    )
    if return_datasets:
        return datasets

    # Only the training loader shuffles; empty splits get None.
    loaders = []
    split_specs = zip(("Training", "Validation", "Test"), datasets, (True, False, False))
    for split_name, dataset, shuffle in split_specs:
        if len(dataset.texts) == 0:
            logger.warning(f"{split_name} dataset is empty. Check your data loading and splitting.")
            loaders.append(None)
        else:
            loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=shuffle))
    return tuple(loaders)